diff --git a/docs/examples/dimensions.ipynb b/docs/examples/dimensions.ipynb index 146146c..bb55136 100644 --- a/docs/examples/dimensions.ipynb +++ b/docs/examples/dimensions.ipynb @@ -134,7 +134,7 @@ "| line | x, color, line_dash, facet_col, facet_row, animation_frame |\n", "| scatter | x, color, symbol, facet_col, facet_row, animation_frame |\n", "| bar | x, color, facet_col, facet_row, animation_frame |\n", - "| imshow | x, y, facet_col, facet_row, animation_frame |" + "| imshow | y, x, facet_col, facet_row, animation_frame |" ] }, { @@ -186,6 +186,16 @@ "xpx(data_3d).line(x=\"year\", facet_col=\"metric\", facet_row=\"country\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Heatmap grids: imshow supports facet_col and facet_row (requires plotly>=6.7)\n", + "xpx(data_3d).imshow(facet_row=\"metric\")" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/tests/test_accessor.py b/tests/test_accessor.py index 23685d7..2dff302 100644 --- a/tests/test_accessor.py +++ b/tests/test_accessor.py @@ -9,7 +9,8 @@ import xarray as xr import xarray_plotly # noqa: F401 - registers accessor -from xarray_plotly import xpx +from xarray_plotly import plotting, xpx +from xarray_plotly.plotting import _imshow_supports_facet_row class TestXpxFunction: @@ -403,6 +404,107 @@ def test_imshow_animation_consistent_bounds(self) -> None: assert coloraxis.cmax == 70.0 +requires_imshow_facet_row = pytest.mark.skipif( + not _imshow_supports_facet_row(), + reason="facet_row in px.imshow requires plotly>=6.7.0", +) + + +class TestImshowFaceting: + """Tests for imshow facet_col and facet_row.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + """Set up test data.""" + self.da_3d = xr.DataArray( + np.random.rand(4, 5, 3), + dims=["lat", "lon", "scenario"], + coords={"scenario": ["a", "b", "c"]}, + name="temperature", + ) + self.da_4d = xr.DataArray( + np.random.rand(4, 5, 2, 3), + dims=["lat", "lon", "scenario", "year"], + coords={"scenario": ["low", "high"], "year": [2020, 2021, 2022]}, + name="temperature", + ) + + def test_imshow_facet_col(self) -> None: + """Test imshow with facet_col creates one subplot per value.""" + fig = self.da_3d.plotly.imshow() + assert len(fig.data) == 3 + xaxes = [k for k in fig.layout if k.startswith("xaxis")] + assert len(xaxes) == 3 + + @requires_imshow_facet_row + def test_imshow_facet_row_explicit(self) -> None: + """Test imshow with explicit facet_row creates one subplot row per value.""" + fig = self.da_3d.plotly.imshow(facet_col=None, facet_row="scenario") + assert len(fig.data) == 3 + yaxes = [k for k in fig.layout if k.startswith("yaxis")] + assert len(yaxes) == 3 + annotations = {a.text for a in fig.layout.annotations} + assert annotations == {"scenario=a", "scenario=b", "scenario=c"} + + @requires_imshow_facet_row + def test_imshow_facet_row_auto_4d(self) -> None: + """Test that a 4D array auto-assigns facet_col and facet_row.""" + fig = self.da_4d.plotly.imshow() + # 2 facet columns (scenario) x 3 facet rows (year) + assert len(fig.data) == 6 + annotations = {a.text for a in fig.layout.annotations} + assert annotations == { + "scenario=low", + "scenario=high", + "year=2020", + "year=2021", + "year=2022", + } + + @requires_imshow_facet_row + def test_imshow_facet_grid_consistent_bounds(self) -> None: + """Test that facet grid subplots share global color bounds.""" + da = xr.DataArray( + np.arange(24, dtype=float).reshape(2, 2, 2, 3), + dims=["y", "x", "scenario", "year"], + ) + fig = da.plotly.imshow() + coloraxis = fig.layout.coloraxis + assert coloraxis.cmin == 0.0 + assert coloraxis.cmax == 23.0 + + @requires_imshow_facet_row + def test_imshow_facet_grid_with_animation(self) -> None: + """Test imshow with facet_col, facet_row, and animation_frame together.""" + da = xr.DataArray( + np.random.rand(4, 5, 2, 3, 6), + dims=["lat", "lon", "scenario", "year", "time"], + name="temperature", + ) + fig = da.plotly.imshow() + assert len(fig.data) == 6 + assert len(fig.frames) == 6 + + def test_imshow_explicit_facet_row_unsupported_error( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Test informative error when facet_row is requested on old plotly.""" + monkeypatch.setattr(plotting, "_imshow_supports_facet_row", lambda: False) + with pytest.raises(ValueError, match=r"facet_row for imshow requires plotly>=6\.7\.0"): + self.da_3d.plotly.imshow(facet_col=None, facet_row="scenario") + + def test_imshow_auto_skips_facet_row_on_old_plotly( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Test that auto-assignment skips facet_row on old plotly (4th dim animates).""" + monkeypatch.setattr(plotting, "_imshow_supports_facet_row", lambda: False) + fig = self.da_4d.plotly.imshow() + # year (4th dim) falls through to animation_frame instead of facet_row + assert len(fig.frames) == 3 + # only the facet_col (scenario) produces subplots + assert len(fig.data) == 2 + + class TestColorsParameter: """Tests for the unified colors parameter.""" diff --git a/tests/test_common.py b/tests/test_common.py index 1873ec0..7d7b4fc 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -20,6 +20,22 @@ def test_auto_assignment_imshow(self) -> None: slots = assign_slots(["lat", "lon"], "imshow") assert slots == {"y": "lat", "x": "lon"} + def test_auto_assignment_imshow_4d(self) -> None: + """Test that the fourth dimension fills facet_row for imshow.""" + slots = assign_slots(["lat", "lon", "scenario", "year"], "imshow") + assert slots == {"y": "lat", "x": "lon", "facet_col": "scenario", "facet_row": "year"} + + def test_auto_assignment_imshow_5d(self) -> None: + """Test that the fifth dimension fills animation_frame for imshow.""" + slots = assign_slots(["lat", "lon", "scenario", "year", "time"], "imshow") + assert slots == { + "y": "lat", + "x": "lon", + "facet_col": "scenario", + "facet_row": "year", + "animation_frame": "time", + } + def test_auto_assignment_scatter(self) -> None: """Test automatic positional assignment for scatter plots.""" slots = assign_slots(["x_dim", "color_dim"], "scatter") @@ -124,5 +140,6 @@ def test_imshow_slot_order(self) -> None: "y", "x", "facet_col", + "facet_row", "animation_frame", ) diff --git a/xarray_plotly/accessor.py b/xarray_plotly/accessor.py index eb3ddf0..3b88a94 100644 --- a/xarray_plotly/accessor.py +++ b/xarray_plotly/accessor.py @@ -302,6 +302,7 @@ def imshow( x: SlotValue = auto, y: SlotValue = auto, facet_col: SlotValue = auto, + facet_row: SlotValue = auto, animation_frame: SlotValue = auto, robust: bool = False, colors: Colors = None, @@ -309,7 +310,7 @@ def imshow( ) -> go.Figure: """Create an interactive heatmap image. - Slot order: y (rows) -> x (columns) -> facet_col -> animation_frame + Slot order: y (rows) -> x (columns) -> facet_col -> facet_row -> animation_frame Note: **Difference from px.imshow**: Color bounds are computed from the @@ -320,7 +321,10 @@ def imshow( x: Dimension for x-axis (columns). Default: second dimension. y: Dimension for y-axis (rows). Default: first dimension. facet_col: Dimension for subplot columns. Default: third dimension. - animation_frame: Dimension for animation. Default: fourth dimension. + facet_row: Dimension for subplot rows. Default: fourth dimension. + Requires plotly>=6.7.0; on older versions this slot is skipped + during auto-assignment. + animation_frame: Dimension for animation. Default: fifth dimension. robust: If True, use 2nd/98th percentiles for color bounds (handles outliers). colors: Color scale name (e.g., "Viridis", "RdBu"). See module docs. **px_kwargs: Additional arguments passed to `plotly.express.imshow()`. @@ -334,6 +338,7 @@ def imshow( x=x, y=y, facet_col=facet_col, + facet_row=facet_row, animation_frame=animation_frame, robust=robust, colors=colors, diff --git a/xarray_plotly/config.py b/xarray_plotly/config.py index d9d8655..ea65b77 100644 --- a/xarray_plotly/config.py +++ b/xarray_plotly/config.py @@ -43,7 +43,7 @@ "facet_row", "animation_frame", ), - "imshow": ("y", "x", "facet_col", "animation_frame"), + "imshow": ("y", "x", "facet_col", "facet_row", "animation_frame"), "box": ("x", "color", "facet_col", "facet_row", "animation_frame"), "pie": ("names", "facet_col", "facet_row"), } diff --git a/xarray_plotly/plotting.py b/xarray_plotly/plotting.py index a45cbd5..9ae2ec0 100644 --- a/xarray_plotly/plotting.py +++ b/xarray_plotly/plotting.py @@ -4,6 +4,7 @@ from __future__ import annotations +import inspect import warnings from typing import TYPE_CHECKING, Any @@ -614,12 +615,21 @@ def scatter( ) +def _imshow_supports_facet_row() -> bool: + """Check whether the installed plotly version supports facet_row in px.imshow. + + Support was added in plotly 6.7.0. + """ + return "facet_row" in inspect.signature(px.imshow).parameters + + def imshow( darray: DataArray, *, x: SlotValue = auto, y: SlotValue = auto, facet_col: SlotValue = auto, + facet_row: SlotValue = auto, animation_frame: SlotValue = auto, robust: bool = False, colors: Colors = None, @@ -629,7 +639,7 @@ def imshow( Create an interactive heatmap from a DataArray. Both x and y are dimensions. Dimensions fill slots in order: - y (rows) -> x (columns) -> facet_col -> animation_frame + y (rows) -> x (columns) -> facet_col -> facet_row -> animation_frame .. note:: **Difference from plotly.express.imshow**: By default, color bounds @@ -649,8 +659,14 @@ def imshow( Dimension for y-axis (rows). Default: first dimension. facet_col Dimension for subplot columns. Default: third dimension. + facet_row + Dimension for subplot rows. Default: fourth dimension. + Requires plotly>=6.7.0; on older versions this slot is skipped + during auto-assignment (the fourth dimension animates instead). + Note: ``facet_col_wrap`` is ignored by plotly when ``facet_row`` + is set. animation_frame - Dimension for animation. Default: fourth dimension. + Dimension for animation. Default: fifth dimension. robust If True, compute color bounds using 2nd and 98th percentiles for robustness against outliers. Default: False (uses min/max). @@ -668,18 +684,36 @@ def imshow( plotly.graph_objects.Figure """ px_kwargs = resolve_colors(colors, px_kwargs) + + # On plotly < 6.7.0, px.imshow has no facet_row: skip auto-assignment so + # dimensions fall through to animation_frame instead. + if facet_row is auto and not _imshow_supports_facet_row(): + facet_row = None + slots = assign_slots( list(darray.dims), "imshow", y=y, x=x, facet_col=facet_col, + facet_row=facet_row, animation_frame=animation_frame, ) - # Transpose to: y (rows), x (cols), facet_col, animation_frame + facet_row_kwargs: dict[str, Any] = {} + if slots.get("facet_row") is not None: + if not _imshow_supports_facet_row(): + import plotly + + msg = f"facet_row for imshow requires plotly>=6.7.0 (installed: {plotly.__version__})." + raise ValueError(msg) + facet_row_kwargs["facet_row"] = slots["facet_row"] + + # Transpose to: y (rows), x (cols), facet_col, facet_row, animation_frame transpose_order = [ - slots[k] for k in ("y", "x", "facet_col", "animation_frame") if slots.get(k) is not None + slots[k] + for k in ("y", "x", "facet_col", "facet_row", "animation_frame") + if slots.get(k) is not None ] plot_data = darray.transpose(*transpose_order) if transpose_order else darray @@ -701,6 +735,7 @@ def imshow( plot_data, facet_col=slots.get("facet_col"), animation_frame=slots.get("animation_frame"), + **facet_row_kwargs, **px_kwargs, )