diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 9e2465a7..a503592e 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -117,10 +117,19 @@ def _reparse_points( df: pd.DataFrame, transformation: Any, coordinate_system: str, + color_column: str | None = None, ) -> None: - """Re-register a points DataFrame in *sdata_filt* with its transformation.""" + """Re-register a points DataFrame in *sdata_filt* with its transformation. + + ``PointsModel.parse`` silently drops columns whose names collide with + reserved coordinate axes (currently only ``"z"``). When ``color_column`` + names such a column, re-attach it so downstream color lookup can find it. + """ dd_frame = dask.dataframe.from_pandas(df, npartitions=1) - sdata_filt.points[element] = PointsModel.parse(dd_frame, coordinates={"x": "x", "y": "y"}) + parsed = PointsModel.parse(dd_frame, coordinates={"x": "x", "y": "y"}) + if color_column is not None and color_column in df.columns and color_column not in parsed.columns: + parsed[color_column] = dd_frame[color_column] + sdata_filt.points[element] = parsed set_transformation( element=sdata_filt.points[element], transformation=transformation, @@ -820,7 +829,7 @@ def _render_points( # Convert back to dask dataframe to modify sdata transformation_in_cs = sdata_filt.points[element].attrs["transform"][coordinate_system] - _reparse_points(sdata_filt, element, points_for_model, transformation_in_cs, coordinate_system) + _reparse_points(sdata_filt, element, points_for_model, transformation_in_cs, coordinate_system, col_for_color) if col_for_color is not None: assert isinstance(col_for_color, str) @@ -877,6 +886,7 @@ def _render_points( points_pd_with_color, transformation_in_cs, coordinate_system, + col_for_color, ) _warn_groups_ignored_continuous(groups, color_source_vector, col_for_color) @@ -897,7 +907,7 @@ def _render_points( # filter the materialized points, adata, and re-register in sdata_filt points = points[keep].reset_index(drop=True) adata = adata[keep] - _reparse_points(sdata_filt, element, points, transformation_in_cs, coordinate_system) + _reparse_points(sdata_filt, element, points, transformation_in_cs, coordinate_system, col_for_color) # color_source_vector is None when the values aren't categorical if color_source_vector is None and render_params.transfunc is not None: diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index dcc4267c..bbb224e2 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -1006,6 +1006,40 @@ def test_no_table_fallback_warning_for_element_column(caplog): plt.close("all") +def test_render_points_color_by_z_data_column(): + # regression test for #615 + pts = PointsModel.parse( + pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0], "z": [0.1, 0.5, 0.9]}), + ) + assert "z" in pts.columns + sdata = SpatialData(points={"p": pts}) + fig, ax = plt.subplots() + try: + sdata.pl.render_points("p", color="z").pl.show(ax=ax) + finally: + plt.close(fig) + + +def test_render_points_color_by_z_with_extra_columns(): + # regression test for #615 + pts = PointsModel.parse( + pd.DataFrame( + { + "x": [1.0, 2.0, 3.0], + "y": [1.0, 2.0, 3.0], + "z": [0.1, 0.5, 0.9], + "score": [0.0, 0.5, 1.0], + } + ), + ) + sdata = SpatialData(points={"p": pts}) + fig, ax = plt.subplots() + try: + sdata.pl.render_points("p", color="score").pl.show(ax=ax) + finally: + plt.close(fig) + + def test_render_points_disjoint_instance_ids_clear_error(): # regression test for #603: disjoint instance_id values must raise a clear ValueError points = PointsModel.parse(pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0]}))