From e720f47bb495d137ffae99fdcfbda68ca49f2ca6 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Wed, 3 Jun 2026 10:29:52 +0200 Subject: [PATCH] feat(plotting): shared axis labels for faceted plots Plotly Express repeats the x-axis title under every facet column and the y-axis title beside every facet row, and exposes no built-in way to fix this on an existing figure (make_subplots' x_title/y_title only work at figure creation and are not reachable from px). Add a public share_axis_labels(fig) helper that collapses repeated, identical axis titles into one centered label per axis, using annotation specs identical to plotly's built-in shared subplot titles (verified by a sync test against make_subplots output). Titles are only collapsed when actually repeated, so combined figures with differing subplot titles and secondary-y figures pass through unchanged. All plotting/accessor methods gain a shared_axis_labels kwarg, on by default; pass shared_axis_labels=False for plotly's stock behavior. Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/test_accessor.py | 8 +- tests/test_figures.py | 157 ++++++++++++++++++++++++++++++++++++++ xarray_plotly/__init__.py | 2 + xarray_plotly/accessor.py | 60 +++++++++++++++ xarray_plotly/figures.py | 77 +++++++++++++++++++ xarray_plotly/plotting.py | 57 ++++++++++++-- 6 files changed, 349 insertions(+), 12 deletions(-) diff --git a/tests/test_accessor.py b/tests/test_accessor.py index 2dff302..aa8f1a4 100644 --- a/tests/test_accessor.py +++ b/tests/test_accessor.py @@ -443,8 +443,8 @@ def test_imshow_facet_row_explicit(self) -> None: assert len(fig.data) == 3 yaxes = [k for k in fig.layout if k.startswith("yaxis")] assert len(yaxes) == 3 - annotations = {a.text for a in fig.layout.annotations} - assert annotations == {"scenario=a", "scenario=b", "scenario=c"} + facet_titles = {a.text for a in fig.layout.annotations if "=" in (a.text or "")} + assert facet_titles == {"scenario=a", "scenario=b", "scenario=c"} @requires_imshow_facet_row def test_imshow_facet_row_auto_4d(self) -> None: @@ -452,8 +452,8 @@ def test_imshow_facet_row_auto_4d(self) -> None: fig = self.da_4d.plotly.imshow() # 2 facet columns (scenario) x 3 facet rows (year) assert len(fig.data) == 6 - annotations = {a.text for a in fig.layout.annotations} - assert annotations == { + facet_titles = {a.text for a in fig.layout.annotations if "=" in (a.text or "")} + assert facet_titles == { "scenario=low", "scenario=high", "year=2020", diff --git a/tests/test_figures.py b/tests/test_figures.py index 7789daa..b9124d6 100644 --- a/tests/test_figures.py +++ b/tests/test_figures.py @@ -12,6 +12,7 @@ from xarray_plotly import ( add_secondary_y, overlay, + share_axis_labels, simplify_facet_titles, subplots, xpx, @@ -989,3 +990,159 @@ def test_kwarg_value_strips_prefix(self) -> None: def test_kwarg_invalid_raises(self) -> None: with pytest.raises(ValueError, match="facet_titles must be"): xpx(self.da).line(facet_col="country", facet_titles="bogus") # type: ignore[arg-type] + + +def _axis_titles(fig: go.Figure, letter: str) -> list[str]: + """Collect non-empty axis title texts for one axis direction.""" + return [ + fig.layout[k].title.text + for k in fig.layout + if k.startswith(f"{letter}axis") and fig.layout[k].title.text + ] + + +def _shared_label_annotations(fig: go.Figure) -> list[go.layout.Annotation]: + """Collect annotations positioned like shared axis labels (centered bottom/left).""" + return [ + a + for a in fig.layout.annotations + if a.xref == "paper" and a.yref == "paper" and (a.x, a.y) in {(0.5, 0), (0, 0.5)} + ] + + +class TestShareAxisLabels: + """Tests for the share_axis_labels helper and the `shared_axis_labels` kwarg.""" + + @pytest.fixture(autouse=True) + def setup(self) -> None: + self.da = xr.DataArray( + np.random.rand(10, 2, 3), + dims=["time", "scenario", "country"], + coords={"scenario": ["low", "high"], "country": ["US", "CN", "BR"]}, + name="value", + ) + + def test_matches_plotly_builtin_spec(self) -> None: + """Shared labels must be identical to make_subplots(x_title=, y_title=) output.""" + import json + + from plotly.subplots import make_subplots + + def canonical(annotations: list[go.layout.Annotation]) -> set[str]: + return {json.dumps(a.to_plotly_json(), sort_keys=True) for a in annotations} + + reference = make_subplots(rows=2, cols=2, x_title="value", y_title="value") + ref_specs = canonical(list(reference.layout.annotations)) + + fig = xpx(self.da).line(color=None, facet_col="scenario", facet_row="country") + # Rename both labels to "value" so the comparison covers all other fields + for ann in _shared_label_annotations(fig): + ann.text = "value" + assert canonical(_shared_label_annotations(fig)) == ref_specs + + def test_helper_facet_col_shares_x_label(self) -> None: + fig = xpx(self.da.isel(scenario=0, drop=True)).line( + color=None, facet_col="country", shared_axis_labels=False + ) + assert len(_axis_titles(fig, "x")) == 3 # repeated per column + + share_axis_labels(fig) + + assert _axis_titles(fig, "x") == [] + labels = _shared_label_annotations(fig) + assert [a.text for a in labels] == ["time"] + # y title was never repeated, so it stays a regular axis title + assert _axis_titles(fig, "y") == ["value"] + + def test_helper_facet_row_shares_y_label(self) -> None: + fig = xpx(self.da.isel(scenario=0, drop=True)).line( + color=None, facet_row="country", shared_axis_labels=False + ) + assert len(_axis_titles(fig, "y")) == 3 # repeated per row + + share_axis_labels(fig) + + assert _axis_titles(fig, "y") == [] + labels = _shared_label_annotations(fig) + assert [a.text for a in labels] == ["value"] + assert _axis_titles(fig, "x") == ["time"] + + def test_helper_facet_grid_shares_both(self) -> None: + fig = xpx(self.da).line( + color=None, facet_col="scenario", facet_row="country", shared_axis_labels=False + ) + share_axis_labels(fig) + + assert _axis_titles(fig, "x") == [] + assert _axis_titles(fig, "y") == [] + labels = {a.text for a in _shared_label_annotations(fig)} + assert labels == {"time", "value"} + + def test_helper_no_facets_is_noop(self) -> None: + fig = xpx(self.da.isel(scenario=0, country=0, drop=True)).line() + before = fig.to_plotly_json() + share_axis_labels(fig) + assert fig.to_plotly_json() == before + + def test_helper_idempotent(self) -> None: + fig = xpx(self.da).line(color=None, facet_col="scenario", facet_row="country") + share_axis_labels(fig) + n_annotations = len(fig.layout.annotations) + share_axis_labels(fig) + assert len(fig.layout.annotations) == n_annotations + + def test_helper_preserves_facet_titles(self) -> None: + fig = xpx(self.da).line(color=None, facet_col="scenario", facet_row="country") + share_axis_labels(fig) + texts = [a.text for a in fig.layout.annotations] + assert "scenario=low" in texts + assert "scenario=high" in texts + + def test_helper_different_titles_untouched(self) -> None: + """Figures whose subplots have different axis titles are not collapsed.""" + da1 = xr.DataArray(np.random.rand(5), dims=["time"], name="temperature") + da2 = xr.DataArray(np.random.rand(5), dims=["distance"], name="pressure") + fig = subplots(xpx(da1).line(), xpx(da2).line(), cols=2) + before_x = _axis_titles(fig, "x") + before_y = _axis_titles(fig, "y") + + share_axis_labels(fig) + + assert _axis_titles(fig, "x") == before_x + assert _axis_titles(fig, "y") == before_y + assert _shared_label_annotations(fig) == [] + + def test_helper_secondary_y_untouched(self) -> None: + """Identical titles on a primary and overlaying secondary y axis stay separate.""" + da = xr.DataArray(np.random.rand(5), dims=["time"], name="value") + base = xpx(da).line() + secondary = xpx(da).line() + fig = add_secondary_y(base, secondary) + before_y = _axis_titles(fig, "y") + + share_axis_labels(fig) + + assert _axis_titles(fig, "y") == before_y + + def test_kwarg_shares_by_default(self) -> None: + fig = xpx(self.da).line(color=None, facet_col="scenario", facet_row="country") + assert _axis_titles(fig, "x") == [] + assert _axis_titles(fig, "y") == [] + assert {a.text for a in _shared_label_annotations(fig)} == {"time", "value"} + + def test_kwarg_opt_out(self) -> None: + fig = xpx(self.da).line( + color=None, facet_col="scenario", facet_row="country", shared_axis_labels=False + ) + assert len(_axis_titles(fig, "x")) > 1 + assert len(_axis_titles(fig, "y")) > 1 + assert _shared_label_annotations(fig) == [] + + def test_kwarg_on_imshow(self) -> None: + da = xr.DataArray( + np.random.rand(4, 5, 2, 3), + dims=["lat", "lon", "scenario", "year"], + name="temperature", + ) + fig = da.plotly.imshow() + assert {a.text for a in _shared_label_annotations(fig)} == {"lat", "lon"} diff --git a/xarray_plotly/__init__.py b/xarray_plotly/__init__.py index 49250e1..38feb37 100644 --- a/xarray_plotly/__init__.py +++ b/xarray_plotly/__init__.py @@ -56,6 +56,7 @@ from xarray_plotly.figures import ( add_secondary_y, overlay, + share_axis_labels, simplify_facet_titles, subplots, update_traces, @@ -68,6 +69,7 @@ "auto", "config", "overlay", + "share_axis_labels", "simplify_facet_titles", "subplots", "update_traces", diff --git a/xarray_plotly/accessor.py b/xarray_plotly/accessor.py index 6cdcd37..ec4fa50 100644 --- a/xarray_plotly/accessor.py +++ b/xarray_plotly/accessor.py @@ -55,6 +55,7 @@ def line( animation_frame: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """Create an interactive line plot. @@ -70,6 +71,8 @@ def line( facet_row: Dimension for subplot rows. Default: sixth dimension. animation_frame: Dimension for animation. Default: seventh dimension. colors: Color specification (scale name, list, or dict). See module docs. + shared_axis_labels: If True (default), repeated axis titles on faceted + plots are replaced with a single shared, centered label per axis. **px_kwargs: Additional arguments passed to `plotly.express.line()`. Returns: @@ -86,6 +89,7 @@ def line( animation_frame=animation_frame, colors=colors, facet_titles=facet_titles, + shared_axis_labels=shared_axis_labels, **px_kwargs, ) @@ -100,6 +104,7 @@ def bar( animation_frame: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """Create an interactive bar chart. @@ -114,6 +119,8 @@ def bar( facet_row: Dimension for subplot rows. Default: fifth dimension. animation_frame: Dimension for animation. Default: sixth dimension. colors: Color specification (scale name, list, or dict). See module docs. + shared_axis_labels: If True (default), repeated axis titles on faceted + plots are replaced with a single shared, centered label per axis. **px_kwargs: Additional arguments passed to `plotly.express.bar()`. Returns: @@ -129,6 +136,7 @@ def bar( animation_frame=animation_frame, colors=colors, facet_titles=facet_titles, + shared_axis_labels=shared_axis_labels, **px_kwargs, ) @@ -143,6 +151,7 @@ def area( animation_frame: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """Create an interactive stacked area chart. @@ -157,6 +166,8 @@ def area( facet_row: Dimension for subplot rows. Default: fifth dimension. animation_frame: Dimension for animation. Default: sixth dimension. colors: Color specification (scale name, list, or dict). See module docs. + shared_axis_labels: If True (default), repeated axis titles on faceted + plots are replaced with a single shared, centered label per axis. **px_kwargs: Additional arguments passed to `plotly.express.area()`. Returns: @@ -172,6 +183,7 @@ def area( animation_frame=animation_frame, colors=colors, facet_titles=facet_titles, + shared_axis_labels=shared_axis_labels, **px_kwargs, ) @@ -185,6 +197,7 @@ def fast_bar( animation_frame: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """Create a bar-like chart using stacked areas for better performance. @@ -198,6 +211,8 @@ def fast_bar( facet_row: Dimension for subplot rows. Default: fourth dimension. animation_frame: Dimension for animation. Default: fifth dimension. colors: Color specification (scale name, list, or dict). See module docs. + shared_axis_labels: If True (default), repeated axis titles on faceted + plots are replaced with a single shared, centered label per axis. **px_kwargs: Additional arguments passed to `plotly.express.area()`. Returns: @@ -212,6 +227,7 @@ def fast_bar( animation_frame=animation_frame, colors=colors, facet_titles=facet_titles, + shared_axis_labels=shared_axis_labels, **px_kwargs, ) @@ -227,6 +243,7 @@ def scatter( animation_frame: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """Create an interactive scatter plot. @@ -246,6 +263,8 @@ def scatter( facet_row: Dimension for subplot rows. Default: fifth dimension. animation_frame: Dimension for animation. Default: sixth dimension. colors: Color specification (scale name, list, or dict). See module docs. + shared_axis_labels: If True (default), repeated axis titles on faceted + plots are replaced with a single shared, centered label per axis. **px_kwargs: Additional arguments passed to `plotly.express.scatter()`. Returns: @@ -262,6 +281,7 @@ def scatter( animation_frame=animation_frame, colors=colors, facet_titles=facet_titles, + shared_axis_labels=shared_axis_labels, **px_kwargs, ) @@ -275,6 +295,7 @@ def box( animation_frame: SlotValue = None, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """Create an interactive box plot. @@ -291,6 +312,8 @@ def box( facet_row: Dimension for subplot rows. Default: None (aggregated). animation_frame: Dimension for animation. Default: None (aggregated). colors: Color specification (scale name, list, or dict). See module docs. + shared_axis_labels: If True (default), repeated axis titles on faceted + plots are replaced with a single shared, centered label per axis. **px_kwargs: Additional arguments passed to `plotly.express.box()`. Returns: @@ -305,6 +328,7 @@ def box( animation_frame=animation_frame, colors=colors, facet_titles=facet_titles, + shared_axis_labels=shared_axis_labels, **px_kwargs, ) @@ -319,6 +343,7 @@ def imshow( robust: bool = False, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """Create an interactive heatmap image. @@ -340,6 +365,8 @@ def imshow( animation_frame: Dimension for animation. Default: fifth dimension. robust: If True, use 2nd/98th percentiles for color bounds (handles outliers). colors: Color scale name (e.g., "Viridis", "RdBu"). See module docs. + shared_axis_labels: If True (default), repeated axis titles on faceted + plots are replaced with a single shared, centered label per axis. **px_kwargs: Additional arguments passed to `plotly.express.imshow()`. Use `zmin` and `zmax` to manually set color scale bounds. @@ -356,6 +383,7 @@ def imshow( robust=robust, colors=colors, facet_titles=facet_titles, + shared_axis_labels=shared_axis_labels, **px_kwargs, ) @@ -368,6 +396,7 @@ def pie( facet_row: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """Create an interactive pie chart. @@ -380,6 +409,8 @@ def pie( facet_col: Dimension for subplot columns. Default: second dimension. facet_row: Dimension for subplot rows. Default: third dimension. colors: Color specification (scale name, list, or dict). See module docs. + shared_axis_labels: If True (default), repeated axis titles on faceted + plots are replaced with a single shared, centered label per axis. **px_kwargs: Additional arguments passed to `plotly.express.pie()`. Returns: @@ -393,6 +424,7 @@ def pie( facet_row=facet_row, colors=colors, facet_titles=facet_titles, + shared_axis_labels=shared_axis_labels, **px_kwargs, ) @@ -474,6 +506,7 @@ def line( animation_frame: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """Create an interactive line plot. @@ -488,6 +521,8 @@ def line( facet_row: Dimension for subplot rows. animation_frame: Dimension for animation. colors: Color specification (scale name, list, or dict). See module docs. + shared_axis_labels: If True (default), repeated axis titles on faceted + plots are replaced with a single shared, centered label per axis. **px_kwargs: Additional arguments passed to `plotly.express.line()`. Returns: @@ -505,6 +540,7 @@ def line( animation_frame=animation_frame, colors=colors, facet_titles=facet_titles, + shared_axis_labels=shared_axis_labels, **px_kwargs, ) @@ -520,6 +556,7 @@ def bar( animation_frame: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """Create an interactive bar chart. @@ -533,6 +570,8 @@ def bar( facet_row: Dimension for subplot rows. animation_frame: Dimension for animation. colors: Color specification (scale name, list, or dict). See module docs. + shared_axis_labels: If True (default), repeated axis titles on faceted + plots are replaced with a single shared, centered label per axis. **px_kwargs: Additional arguments passed to `plotly.express.bar()`. Returns: @@ -549,6 +588,7 @@ def bar( animation_frame=animation_frame, colors=colors, facet_titles=facet_titles, + shared_axis_labels=shared_axis_labels, **px_kwargs, ) @@ -564,6 +604,7 @@ def area( animation_frame: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """Create an interactive stacked area chart. @@ -577,6 +618,8 @@ def area( facet_row: Dimension for subplot rows. animation_frame: Dimension for animation. colors: Color specification (scale name, list, or dict). See module docs. + shared_axis_labels: If True (default), repeated axis titles on faceted + plots are replaced with a single shared, centered label per axis. **px_kwargs: Additional arguments passed to `plotly.express.area()`. Returns: @@ -593,6 +636,7 @@ def area( animation_frame=animation_frame, colors=colors, facet_titles=facet_titles, + shared_axis_labels=shared_axis_labels, **px_kwargs, ) @@ -607,6 +651,7 @@ def fast_bar( animation_frame: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """Create a bar-like chart using stacked areas for better performance. @@ -619,6 +664,8 @@ def fast_bar( facet_row: Dimension for subplot rows. animation_frame: Dimension for animation. colors: Color specification (scale name, list, or dict). See module docs. + shared_axis_labels: If True (default), repeated axis titles on faceted + plots are replaced with a single shared, centered label per axis. **px_kwargs: Additional arguments passed to `plotly.express.area()`. Returns: @@ -634,6 +681,7 @@ def fast_bar( animation_frame=animation_frame, colors=colors, facet_titles=facet_titles, + shared_axis_labels=shared_axis_labels, **px_kwargs, ) @@ -650,6 +698,7 @@ def scatter( animation_frame: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """Create an interactive scatter plot. @@ -664,6 +713,8 @@ def scatter( facet_row: Dimension for subplot rows. animation_frame: Dimension for animation. colors: Color specification (scale name, list, or dict). See module docs. + shared_axis_labels: If True (default), repeated axis titles on faceted + plots are replaced with a single shared, centered label per axis. **px_kwargs: Additional arguments passed to `plotly.express.scatter()`. Returns: @@ -681,6 +732,7 @@ def scatter( animation_frame=animation_frame, colors=colors, facet_titles=facet_titles, + shared_axis_labels=shared_axis_labels, **px_kwargs, ) @@ -695,6 +747,7 @@ def box( animation_frame: SlotValue = None, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """Create an interactive box plot. @@ -707,6 +760,8 @@ def box( facet_row: Dimension for subplot rows. animation_frame: Dimension for animation. colors: Color specification (scale name, list, or dict). See module docs. + shared_axis_labels: If True (default), repeated axis titles on faceted + plots are replaced with a single shared, centered label per axis. **px_kwargs: Additional arguments passed to `plotly.express.box()`. Returns: @@ -722,6 +777,7 @@ def box( animation_frame=animation_frame, colors=colors, facet_titles=facet_titles, + shared_axis_labels=shared_axis_labels, **px_kwargs, ) @@ -735,6 +791,7 @@ def pie( facet_row: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """Create an interactive pie chart. @@ -746,6 +803,8 @@ def pie( facet_col: Dimension for subplot columns. facet_row: Dimension for subplot rows. colors: Color specification (scale name, list, or dict). See module docs. + shared_axis_labels: If True (default), repeated axis titles on faceted + plots are replaced with a single shared, centered label per axis. **px_kwargs: Additional arguments passed to `plotly.express.pie()`. Returns: @@ -760,5 +819,6 @@ def pie( facet_row=facet_row, colors=colors, facet_titles=facet_titles, + shared_axis_labels=shared_axis_labels, **px_kwargs, ) diff --git a/xarray_plotly/figures.py b/xarray_plotly/figures.py index 2d9b231..0c9f9d7 100644 --- a/xarray_plotly/figures.py +++ b/xarray_plotly/figures.py @@ -971,3 +971,80 @@ def simplify_facet_titles( if text and _FACET_TITLE_PREFIX_RE.match(text): ann.text = text.split("=", 1)[1] return fig + + +# Matches cartesian axis layout keys like "xaxis", "xaxis2", "yaxis12". +_AXIS_KEY_RE = re.compile(r"([xy])axis\d*$") + +# Annotation specs identical to plotly's built-in shared subplot titles, +# i.e. what `make_subplots(x_title=..., y_title=...)` produces. Kept in +# sync with plotly via a test against make_subplots output. +_SHARED_LABEL_SPECS: dict[str, dict[str, Any]] = { + "x": { + "x": 0.5, + "y": 0, + "xref": "paper", + "yref": "paper", + "xanchor": "center", + "yanchor": "top", + "yshift": -30, + "showarrow": False, + "font": {"size": 16}, + }, + "y": { + "x": 0, + "y": 0.5, + "xref": "paper", + "yref": "paper", + "xanchor": "right", + "yanchor": "middle", + "xshift": -40, + "textangle": -90, + "showarrow": False, + "font": {"size": 16}, + }, +} + + +def share_axis_labels(fig: go.Figure) -> go.Figure: + """Replace repeated facet axis titles with a single shared label per axis. + + Plotly Express repeats the x-axis title under every facet column and the + y-axis title beside every facet row. This helper removes the repeated + titles and adds one centered label per axis instead, styled exactly like + plotly's built-in shared titles (``make_subplots(x_title=..., y_title=...)``), + which Plotly Express does not expose for faceted figures. + + Titles are only collapsed when they are repeated and identical, so + figures without facets, figures combined from differently-labeled + subplots, and secondary-y figures are returned unchanged. + + Args: + fig: A Plotly figure (mutated in place). + + Returns: + The (possibly mutated) figure, for chaining. + + Example: + >>> import plotly.express as px + >>> from xarray_plotly import share_axis_labels + >>> fig = px.line(df, x="year", y="gdp", facet_col="country", facet_row="metric") + >>> share_axis_labels(fig) # one "year" below, one "gdp" at the left + """ + axes_by_letter: dict[str, list[Any]] = {"x": [], "y": []} + for key in fig.layout: + match = _AXIS_KEY_RE.match(key) + # Overlaying axes (secondary y) share their domain with the axis + # they overlay; their titles are independent, not facet repetition. + if match and not fig.layout[key].overlaying: + axes_by_letter[match.group(1)].append(fig.layout[key]) + + for letter, axes in axes_by_letter.items(): + titles = [axis.title.text for axis in axes if axis.title.text] + # Only collapse titles that are actually repeated and identical + if len(titles) < 2 or len(set(titles)) != 1: + continue + for axis in axes: + axis.title.text = None + fig.add_annotation(text=titles[0], **_SHARED_LABEL_SPECS[letter]) + return fig diff --git a/xarray_plotly/plotting.py b/xarray_plotly/plotting.py index aa14072..d52abce 100644 --- a/xarray_plotly/plotting.py +++ b/xarray_plotly/plotting.py @@ -26,6 +26,7 @@ ) from xarray_plotly.figures import ( _iter_all_traces, + share_axis_labels, simplify_facet_titles, ) @@ -34,6 +35,14 @@ from xarray import DataArray +def _finalize(fig: go.Figure, facet_titles: FacetTitlesMode, shared_axis_labels: bool) -> go.Figure: + """Apply facet post-processing shared by all plot functions.""" + simplify_facet_titles(fig, facet_titles) + if shared_axis_labels: + share_axis_labels(fig) + return fig + + def line( darray: DataArray, *, @@ -46,6 +55,7 @@ def line( animation_frame: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """ @@ -79,6 +89,9 @@ def line( - A list of colors (e.g., ["red", "blue", "green"]) - A dict mapping values to colors (e.g., {"A": "red", "B": "blue"}) Explicit color_* kwargs in px_kwargs take precedence. + shared_axis_labels + If True (default), repeated axis titles on faceted plots are + replaced with a single shared, centered label per axis. **px_kwargs Additional arguments passed to `plotly.express.line()`. @@ -116,7 +129,7 @@ def line( labels=labels, **px_kwargs, ) - return simplify_facet_titles(fig, facet_titles) + return _finalize(fig, facet_titles, shared_axis_labels) def bar( @@ -130,6 +143,7 @@ def bar( animation_frame: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """ @@ -161,6 +175,9 @@ def bar( - A list of colors (e.g., ["red", "blue", "green"]) - A dict mapping values to colors (e.g., {"A": "red", "B": "blue"}) Explicit color_* kwargs in px_kwargs take precedence. + shared_axis_labels + If True (default), repeated axis titles on faceted plots are + replaced with a single shared, centered label per axis. **px_kwargs Additional arguments passed to `plotly.express.bar()`. @@ -196,7 +213,7 @@ def bar( labels=labels, **px_kwargs, ) - return simplify_facet_titles(fig, facet_titles) + return _finalize(fig, facet_titles, shared_axis_labels) def _classify_trace_sign(y_values: npt.ArrayLike) -> str: @@ -293,6 +310,7 @@ def fast_bar( animation_frame: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """ @@ -330,6 +348,9 @@ def fast_bar( - A list of colors (e.g., ["red", "blue", "green"]) - A dict mapping values to colors (e.g., {"A": "red", "B": "blue"}) Explicit color_* kwargs in px_kwargs take precedence. + shared_axis_labels + If True (default), repeated axis titles on faceted plots are + replaced with a single shared, centered label per axis. **px_kwargs Additional arguments passed to `plotly.express.area()`. @@ -367,7 +388,7 @@ def fast_bar( _style_traces_as_bars(fig) - return simplify_facet_titles(fig, facet_titles) + return _finalize(fig, facet_titles, shared_axis_labels) def area( @@ -381,6 +402,7 @@ def area( animation_frame: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """ @@ -412,6 +434,9 @@ def area( - A list of colors (e.g., ["red", "blue", "green"]) - A dict mapping values to colors (e.g., {"A": "red", "B": "blue"}) Explicit color_* kwargs in px_kwargs take precedence. + shared_axis_labels + If True (default), repeated axis titles on faceted plots are + replaced with a single shared, centered label per axis. **px_kwargs Additional arguments passed to `plotly.express.area()`. @@ -447,7 +472,7 @@ def area( labels=labels, **px_kwargs, ) - return simplify_facet_titles(fig, facet_titles) + return _finalize(fig, facet_titles, shared_axis_labels) def box( @@ -460,6 +485,7 @@ def box( animation_frame: SlotValue = None, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """ @@ -491,6 +517,9 @@ def box( - A list of colors (e.g., ["red", "blue", "green"]) - A dict mapping values to colors (e.g., {"A": "red", "B": "blue"}) Explicit color_* kwargs in px_kwargs take precedence. + shared_axis_labels + If True (default), repeated axis titles on faceted plots are + replaced with a single shared, centered label per axis. **px_kwargs Additional arguments passed to `plotly.express.box()`. @@ -525,7 +554,7 @@ def box( labels=labels, **px_kwargs, ) - return simplify_facet_titles(fig, facet_titles) + return _finalize(fig, facet_titles, shared_axis_labels) def scatter( @@ -540,6 +569,7 @@ def scatter( animation_frame: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """ @@ -578,6 +608,9 @@ def scatter( - A list of colors (e.g., ["red", "blue", "green"]) - A dict mapping values to colors (e.g., {"A": "red", "B": "blue"}) Explicit color_* kwargs in px_kwargs take precedence. + shared_axis_labels + If True (default), repeated axis titles on faceted plots are + replaced with a single shared, centered label per axis. **px_kwargs Additional arguments passed to `plotly.express.scatter()`. @@ -625,7 +658,7 @@ def scatter( labels=labels, **px_kwargs, ) - return simplify_facet_titles(fig, facet_titles) + return _finalize(fig, facet_titles, shared_axis_labels) def _imshow_supports_facet_row() -> bool: @@ -647,6 +680,7 @@ def imshow( robust: bool = False, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """ @@ -689,6 +723,9 @@ def imshow( continuous scale (e.g., "Viridis", "RdBu"). Lists and dicts are not applicable for heatmaps. Explicit color_continuous_scale in px_kwargs takes precedence. + shared_axis_labels + If True (default), repeated axis titles on faceted plots are + replaced with a single shared, centered label per axis. **px_kwargs Additional arguments passed to `plotly.express.imshow()`. Use `zmin` and `zmax` to manually set color scale bounds. @@ -752,7 +789,7 @@ def imshow( **facet_row_kwargs, **px_kwargs, ) - return simplify_facet_titles(fig, facet_titles) + return _finalize(fig, facet_titles, shared_axis_labels) def pie( @@ -764,6 +801,7 @@ def pie( facet_row: SlotValue = auto, colors: Colors = None, facet_titles: FacetTitlesMode = "default", + shared_axis_labels: bool = True, **px_kwargs: Any, ) -> go.Figure: """ @@ -790,6 +828,9 @@ def pie( - A list of colors (e.g., ["red", "blue", "green"]) - A dict mapping values to colors (e.g., {"A": "red", "B": "blue"}) Explicit color_* kwargs in px_kwargs take precedence. + shared_axis_labels + If True (default), repeated axis titles on faceted plots are + replaced with a single shared, centered label per axis. **px_kwargs Additional arguments passed to `plotly.express.pie()`. @@ -823,4 +864,4 @@ def pie( labels=labels, **px_kwargs, ) - return simplify_facet_titles(fig, facet_titles) + return _finalize(fig, facet_titles, shared_axis_labels)