Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,17 +443,17 @@ def test_imshow_facet_row_explicit(self) -> None:
assert len(fig.data) == 3
yaxes = [k for k in fig.layout if k.startswith("yaxis")]
assert len(yaxes) == 3
annotations = {a.text for a in fig.layout.annotations}
assert annotations == {"scenario=a", "scenario=b", "scenario=c"}
facet_titles = {a.text for a in fig.layout.annotations if "=" in (a.text or "")}
assert facet_titles == {"scenario=a", "scenario=b", "scenario=c"}

@requires_imshow_facet_row
def test_imshow_facet_row_auto_4d(self) -> None:
"""Test that a 4D array auto-assigns facet_col and facet_row."""
fig = self.da_4d.plotly.imshow()
# 2 facet columns (scenario) x 3 facet rows (year)
assert len(fig.data) == 6
annotations = {a.text for a in fig.layout.annotations}
assert annotations == {
facet_titles = {a.text for a in fig.layout.annotations if "=" in (a.text or "")}
assert facet_titles == {
"scenario=low",
"scenario=high",
"year=2020",
Expand Down
157 changes: 157 additions & 0 deletions tests/test_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from xarray_plotly import (
add_secondary_y,
overlay,
share_axis_labels,
simplify_facet_titles,
subplots,
xpx,
Expand Down Expand Up @@ -989,3 +990,159 @@ def test_kwarg_value_strips_prefix(self) -> None:
def test_kwarg_invalid_raises(self) -> None:
with pytest.raises(ValueError, match="facet_titles must be"):
xpx(self.da).line(facet_col="country", facet_titles="bogus") # type: ignore[arg-type]


def _axis_titles(fig: go.Figure, letter: str) -> list[str]:
"""Collect non-empty axis title texts for one axis direction."""
return [
fig.layout[k].title.text
for k in fig.layout
if k.startswith(f"{letter}axis") and fig.layout[k].title.text
]


def _shared_label_annotations(fig: go.Figure) -> list[go.layout.Annotation]:
"""Collect annotations positioned like shared axis labels (centered bottom/left)."""
return [
a
for a in fig.layout.annotations
if a.xref == "paper" and a.yref == "paper" and (a.x, a.y) in {(0.5, 0), (0, 0.5)}
]


class TestShareAxisLabels:
"""Tests for the share_axis_labels helper and the `shared_axis_labels` kwarg."""

@pytest.fixture(autouse=True)
def setup(self) -> None:
self.da = xr.DataArray(
np.random.rand(10, 2, 3),
dims=["time", "scenario", "country"],
coords={"scenario": ["low", "high"], "country": ["US", "CN", "BR"]},
name="value",
)

def test_matches_plotly_builtin_spec(self) -> None:
"""Shared labels must be identical to make_subplots(x_title=, y_title=) output."""
import json

from plotly.subplots import make_subplots

def canonical(annotations: list[go.layout.Annotation]) -> set[str]:
return {json.dumps(a.to_plotly_json(), sort_keys=True) for a in annotations}

reference = make_subplots(rows=2, cols=2, x_title="value", y_title="value")
ref_specs = canonical(list(reference.layout.annotations))

fig = xpx(self.da).line(color=None, facet_col="scenario", facet_row="country")
# Rename both labels to "value" so the comparison covers all other fields
for ann in _shared_label_annotations(fig):
ann.text = "value"
assert canonical(_shared_label_annotations(fig)) == ref_specs

def test_helper_facet_col_shares_x_label(self) -> None:
fig = xpx(self.da.isel(scenario=0, drop=True)).line(
color=None, facet_col="country", shared_axis_labels=False
)
assert len(_axis_titles(fig, "x")) == 3 # repeated per column

share_axis_labels(fig)

assert _axis_titles(fig, "x") == []
labels = _shared_label_annotations(fig)
assert [a.text for a in labels] == ["time"]
# y title was never repeated, so it stays a regular axis title
assert _axis_titles(fig, "y") == ["value"]

def test_helper_facet_row_shares_y_label(self) -> None:
fig = xpx(self.da.isel(scenario=0, drop=True)).line(
color=None, facet_row="country", shared_axis_labels=False
)
assert len(_axis_titles(fig, "y")) == 3 # repeated per row

share_axis_labels(fig)

assert _axis_titles(fig, "y") == []
labels = _shared_label_annotations(fig)
assert [a.text for a in labels] == ["value"]
assert _axis_titles(fig, "x") == ["time"]

def test_helper_facet_grid_shares_both(self) -> None:
fig = xpx(self.da).line(
color=None, facet_col="scenario", facet_row="country", shared_axis_labels=False
)
share_axis_labels(fig)

assert _axis_titles(fig, "x") == []
assert _axis_titles(fig, "y") == []
labels = {a.text for a in _shared_label_annotations(fig)}
assert labels == {"time", "value"}

def test_helper_no_facets_is_noop(self) -> None:
fig = xpx(self.da.isel(scenario=0, country=0, drop=True)).line()
before = fig.to_plotly_json()
share_axis_labels(fig)
assert fig.to_plotly_json() == before

def test_helper_idempotent(self) -> None:
fig = xpx(self.da).line(color=None, facet_col="scenario", facet_row="country")
share_axis_labels(fig)
n_annotations = len(fig.layout.annotations)
share_axis_labels(fig)
assert len(fig.layout.annotations) == n_annotations

def test_helper_preserves_facet_titles(self) -> None:
fig = xpx(self.da).line(color=None, facet_col="scenario", facet_row="country")
share_axis_labels(fig)
texts = [a.text for a in fig.layout.annotations]
assert "scenario=low" in texts
assert "scenario=high" in texts

def test_helper_different_titles_untouched(self) -> None:
"""Figures whose subplots have different axis titles are not collapsed."""
da1 = xr.DataArray(np.random.rand(5), dims=["time"], name="temperature")
da2 = xr.DataArray(np.random.rand(5), dims=["distance"], name="pressure")
fig = subplots(xpx(da1).line(), xpx(da2).line(), cols=2)
before_x = _axis_titles(fig, "x")
before_y = _axis_titles(fig, "y")

share_axis_labels(fig)

assert _axis_titles(fig, "x") == before_x
assert _axis_titles(fig, "y") == before_y
assert _shared_label_annotations(fig) == []

def test_helper_secondary_y_untouched(self) -> None:
"""Identical titles on a primary and overlaying secondary y axis stay separate."""
da = xr.DataArray(np.random.rand(5), dims=["time"], name="value")
base = xpx(da).line()
secondary = xpx(da).line()
fig = add_secondary_y(base, secondary)
before_y = _axis_titles(fig, "y")

share_axis_labels(fig)

assert _axis_titles(fig, "y") == before_y

def test_kwarg_shares_by_default(self) -> None:
fig = xpx(self.da).line(color=None, facet_col="scenario", facet_row="country")
assert _axis_titles(fig, "x") == []
assert _axis_titles(fig, "y") == []
assert {a.text for a in _shared_label_annotations(fig)} == {"time", "value"}

def test_kwarg_opt_out(self) -> None:
fig = xpx(self.da).line(
color=None, facet_col="scenario", facet_row="country", shared_axis_labels=False
)
assert len(_axis_titles(fig, "x")) > 1
assert len(_axis_titles(fig, "y")) > 1
assert _shared_label_annotations(fig) == []

def test_kwarg_on_imshow(self) -> None:
da = xr.DataArray(
np.random.rand(4, 5, 2, 3),
dims=["lat", "lon", "scenario", "year"],
name="temperature",
)
fig = da.plotly.imshow()
assert {a.text for a in _shared_label_annotations(fig)} == {"lat", "lon"}
2 changes: 2 additions & 0 deletions xarray_plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from xarray_plotly.figures import (
add_secondary_y,
overlay,
share_axis_labels,
simplify_facet_titles,
subplots,
update_traces,
Expand All @@ -68,6 +69,7 @@
"auto",
"config",
"overlay",
"share_axis_labels",
"simplify_facet_titles",
"subplots",
"update_traces",
Expand Down
Loading