Skip to content

nan handling for continuous and categorical coloring #427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
99 changes: 85 additions & 14 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,14 @@ def _render_shapes(
transformed_element[col_for_color] = color_vector if color_source_vector is None else color_source_vector
# Render shapes with datashader
color_by_categorical = col_for_color is not None and color_source_vector is not None

aggregate_with_reduction = None
continuous_nan_shapes = None
if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1):
if color_by_categorical:
agg = cvs.polygons(
transformed_element,
geometry="geometry",
agg=ds.by(col_for_color, ds.count()),
)
# add nan as a category so that shapes with nan value are colored in the nan color
transformed_element[col_for_color] = transformed_element[col_for_color].cat.add_categories("nan")
agg = cvs.polygons(transformed_element, geometry="geometry", agg=ds.by(col_for_color, ds.count()))
else:
reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "mean"
logger.info(
Expand All @@ -232,6 +232,13 @@ def _render_shapes(
)
# save min and max values for drawing the colorbar
aggregate_with_reduction = (agg.min(), agg.max())

# nan shapes need to be rendered separately (else: invisible, bc nan is skipped by aggregation methods)
transformed_element_nan_color = transformed_element[transformed_element[col_for_color].isnull()]
if len(transformed_element_nan_color) > 0:
continuous_nan_shapes = _datashader_aggregate_with_function(
"any", cvs, transformed_element_nan_color, None, "shapes"
)
else:
agg = cvs.polygons(transformed_element, geometry="geometry", agg=ds.count())
# render outlines if needed
Expand Down Expand Up @@ -297,6 +304,18 @@ def _render_shapes(
clip=norm.clip,
) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes

if continuous_nan_shapes is not None:
# for coloring by continuous variable: render nan shapes separately
nan_color = render_params.cmap_params.na_color
if isinstance(nan_color, str) and nan_color.startswith("#") and len(nan_color) == 9:
nan_color = nan_color[:7]
continuous_nan_shapes = ds.tf.shade(
continuous_nan_shapes,
cmap=nan_color,
how="linear",
min_alpha=np.min([254, render_params.fill_alpha * 255]),
)

# shade outlines if needed
outline_color = render_params.outline_params.outline_color
if isinstance(outline_color, str) and outline_color.startswith("#") and len(outline_color) == 9:
Expand All @@ -314,6 +333,17 @@ def _render_shapes(
how="linear",
) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes

if continuous_nan_shapes is not None:
# for coloring by continuous variable: render nan shapes separately
rgba_image_nan, trans_data_nan = _create_image_from_datashader_result(continuous_nan_shapes, factor, ax)
_ax_show_and_transform(
rgba_image_nan,
trans_data_nan,
ax,
zorder=render_params.zorder,
alpha=render_params.fill_alpha,
extent=x_ext + y_ext,
)
rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
_cax = _ax_show_and_transform(
rgba_image,
Expand Down Expand Up @@ -353,7 +383,7 @@ def _render_shapes(
_cax = _get_collection_shape(
shapes=shapes,
s=render_params.scale,
c=color_vector,
c=color_vector.copy(), # copy bc c is modified in _get_collection_shape
render_params=render_params,
rasterized=sc_settings._vector_friendly,
cmap=render_params.cmap_params.cmap,
Expand All @@ -373,8 +403,8 @@ def _render_shapes(
# If the user passed a Normalize object with vmin/vmax we'll use those,
# if not we'll use the min/max of the color_vector
_cax.set_clim(
vmin=render_params.cmap_params.norm.vmin or min(color_vector),
vmax=render_params.cmap_params.norm.vmax or max(color_vector),
vmin=render_params.cmap_params.norm.vmin or np.nanmin(color_vector),
vmax=render_params.cmap_params.norm.vmax or np.nanmax(color_vector),
)

if len(set(color_vector)) != 1 or list(set(color_vector))[0] != to_hex(render_params.cmap_params.na_color):
Expand Down Expand Up @@ -588,15 +618,28 @@ def _render_points(
# use datashader for the visualization of points
cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_ext, y_range=y_ext)

# in case we are coloring by a column in table
if col_for_color is not None and col_for_color not in transformed_element.columns:
if color_source_vector is not None:
transformed_element = transformed_element.assign(col_for_color=pd.Series(color_source_vector))
else:
transformed_element = transformed_element.assign(col_for_color=pd.Series(color_vector))
transformed_element = transformed_element.rename(columns={"col_for_color": col_for_color})

color_by_categorical = col_for_color is not None and transformed_element[col_for_color].values.dtype in (
object,
"categorical",
)
if color_by_categorical and transformed_element[col_for_color].values.dtype == object:
transformed_element[col_for_color] = transformed_element[col_for_color].astype("category")

aggregate_with_reduction = None
continuous_nan_points = None
if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1):
if color_by_categorical:
# add nan as category so that nan points are shown in the nan color
transformed_element[col_for_color] = transformed_element[col_for_color].cat.as_known()
transformed_element[col_for_color] = transformed_element[col_for_color].cat.add_categories("nan")
agg = cvs.points(transformed_element, "x", "y", agg=ds.by(col_for_color, ds.count()))
else:
reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "sum"
Expand All @@ -613,6 +656,12 @@ def _render_points(
)
# save min and max values for drawing the colorbar
aggregate_with_reduction = (agg.min(), agg.max())
# nan points need to be rendered separately (else: invisible, bc nan is skipped by aggregation methods)
transformed_element_nan_color = transformed_element[transformed_element[col_for_color].isnull()]
if len(transformed_element_nan_color) > 0:
continuous_nan_points = _datashader_aggregate_with_function(
"any", cvs, transformed_element_nan_color, None, "points"
)
else:
agg = cvs.points(transformed_element, "x", "y", agg=ds.count())

Expand Down Expand Up @@ -640,12 +689,10 @@ def _render_points(
)

# remove alpha from color if it's hex
if color_key is not None and all(len(x) == 9 for x in color_key) and color_key[0][0] == "#":
color_key = [x[:-2] for x in color_key]
if isinstance(color_vector[0], str) and (
color_vector is not None and all(len(x) == 9 for x in color_vector) and color_vector[0][0] == "#"
):
color_vector = np.asarray([x[:-2] for x in color_vector])
if color_key is not None and color_key[0][0] == "#":
color_key = [_hex_no_alpha(x) for x in color_key]
if isinstance(color_vector[0], str) and (color_vector is not None and color_vector[0][0] == "#"):
color_vector = np.asarray([_hex_no_alpha(x) for x in color_vector])

if color_by_categorical or col_for_color is None:
ds_result = _datashader_map_aggregate_to_color(
Expand Down Expand Up @@ -678,6 +725,29 @@ def _render_points(
min_alpha=np.min([254, render_params.alpha * 255]),
) # prevent min_alpha == 255, bc that led to fully colored test plots instead of just colored points/shapes

if continuous_nan_points is not None:
# for coloring by continuous variable: render nan points separately
nan_color = render_params.cmap_params.na_color
if isinstance(nan_color, str) and nan_color.startswith("#") and len(nan_color) == 9:
nan_color = nan_color[:7]
continuous_nan_points = ds.tf.spread(continuous_nan_points, px=px, how="max")
continuous_nan_points = ds.tf.shade(
continuous_nan_points,
cmap=nan_color,
how="linear",
)

if continuous_nan_points is not None:
# for coloring by continuous variable: render nan points separately
rgba_image_nan, trans_data_nan = _create_image_from_datashader_result(continuous_nan_points, factor, ax)
_ax_show_and_transform(
rgba_image_nan,
trans_data_nan,
ax,
zorder=render_params.zorder,
alpha=render_params.alpha,
extent=x_ext + y_ext,
)
rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
_ax_show_and_transform(
rgba_image,
Expand Down Expand Up @@ -716,6 +786,7 @@ def _render_points(
alpha=render_params.alpha,
transform=trans_data,
zorder=render_params.zorder,
plotnonfinite=True, # nan points should be rendered as well
)
cax = ax.add_collection(_cax)
if update_parameters:
Expand Down
24 changes: 16 additions & 8 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,15 @@ def _get_collection_shape(
c = cmap(c)
else:
try:
norm = colors.Normalize(vmin=min(c), vmax=max(c)) if norm is None else norm
norm = colors.Normalize(vmin=np.nanmin(c), vmax=np.nanmax(c)) if norm is None else norm
except ValueError as e:
raise ValueError(
"Could not convert values in the `color` column to float, if `color` column represents"
" categories, set the column to categorical dtype."
) from e
c = cmap(norm(c))
# normalize only the not nan values, else the whole array would contain only nan values
c[~c.isnull()] = norm(c[~c.isnull()])
c = cmap(c)

fill_c = ColorConverter().to_rgba_array(c)
fill_c[..., -1] *= render_params.fill_alpha
Expand Down Expand Up @@ -796,6 +798,9 @@ def _set_color_source_vec(

# do not rename categories, as colors need not be unique
color_vector = color_source_vector.map(color_mapping)
# nan handling
color_vector = color_vector.add_categories(na_color)
color_vector[pd.isna(color_vector)] = na_color

return color_source_vector, color_vector, True

Expand All @@ -819,15 +824,18 @@ def _map_color_seg(

if pd.api.types.is_categorical_dtype(color_vector.dtype):
# Case A: user wants to plot a categorical column
if np.any(color_source_vector.isna()):
cell_id[color_source_vector.isna()] = 0
Comment on lines -822 to -823
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In `seg`, the value 0 denotes the background, so this would lead to the background being mapped to the NaN color.
The actual label(s) with NaN in `color_source_vector` would no longer have their id in `cell_id`, so they would be mapped to nothing and would look like background.

val_im: ArrayLike = map_array(seg.copy(), cell_id, color_vector.codes + 1)
cols = colors.to_rgba_array(color_vector.categories)
elif pd.api.types.is_numeric_dtype(color_vector.dtype):
# Case B: user wants to plot a continuous column
if isinstance(color_vector, pd.Series):
color_vector = color_vector.to_numpy()
cols = cmap_params.cmap(cmap_params.norm(color_vector))
# normalize only the not nan values, else the whole array would contain only nan values
normed_color_vector = color_vector.copy().astype(float)
normed_color_vector[~np.isnan(normed_color_vector)] = cmap_params.norm(
normed_color_vector[~np.isnan(normed_color_vector)]
)
cols = cmap_params.cmap(normed_color_vector)
val_im = map_array(seg.copy(), cell_id, cell_id)
else:
# Case C: User didn't specify any colors
Expand Down Expand Up @@ -2086,7 +2094,7 @@ def _validate_image_render_params(
def _get_wanted_render_elements(
sdata: SpatialData,
sdata_wanted_elements: list[str],
params: (ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams),
params: ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams,
cs: str,
element_type: Literal["images", "labels", "points", "shapes"],
) -> tuple[list[str], list[str], bool]:
Expand Down Expand Up @@ -2243,7 +2251,7 @@ def _create_image_from_datashader_result(


def _datashader_aggregate_with_function(
reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None),
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
cvs: Canvas,
spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame,
col_for_color: str | None,
Expand Down Expand Up @@ -2307,7 +2315,7 @@ def _datashader_aggregate_with_function(


def _datshader_get_how_kw_for_spread(
reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None),
reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None,
) -> str:
# Get the best input for the how argument of ds.tf.spread(), needed for numerical values
reduction = reduction or "sum"
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
38 changes: 38 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,44 @@ def test_sdata_multiple_images_diverging_dims():
return sdata


@pytest.fixture
def sdata_blobs_points_with_nans_in_table() -> SpatialData:
    """Get blobs sdata where the table annotates the points and includes nan values.

    nan values are placed in ``X`` (first variable), in the continuous ``obs``
    column ``cola`` and in the categorical ``obs`` column ``category``.

    Returns
    -------
    The blobs object with the annotating table stored under the key ``"table"``.
    """
    blob = blobs()
    n_obs = len(blob["blobs_points"])

    # X: random values, with nan in rows 0..29 of the first variable
    adata = AnnData(RNG.normal(size=(n_obs, 2)))
    adata.X[0:30, 0] = np.nan
    adata.var = pd.DataFrame({}, index=["col1", "col2"])

    # obs: random continuous columns; `.loc` is a label-based slice (end label included)
    adata.obs = pd.DataFrame(RNG.normal(size=(n_obs, 3)), columns=["cola", "colb", "colc"])
    adata.obs.loc[0:30, "cola"] = np.nan

    # consecutive integer ids, one per observation (assigned exactly once)
    adata.obs["instance_id"] = list(range(adata.n_obs))
    # categorical column in which every third entry is nan
    # NOTE(review): the length (3 * 50) is hard-coded - confirm it matches the number of points
    adata.obs["category"] = pd.Series(["a", "b", np.nan] * 50, dtype="category")
    adata.obs["region"] = "blobs_points"

    table = TableModel.parse(adata=adata, region_key="region", instance_key="instance_id", region="blobs_points")
    blob["table"] = table
    return blob


@pytest.fixture
def sdata_blobs_shapes_with_nans_in_table() -> SpatialData:
    """Get blobs sdata where the table annotates the shapes and includes nan values.

    nan values are placed in ``X`` (first row, first variable), in the continuous
    ``obs`` column ``cola`` and in the categorical ``obs`` column ``category``.

    Returns
    -------
    The blobs object with the annotating table stored under the key ``"table"``.
    """
    blob = blobs()
    n_obs = len(blob["blobs_polygons"])

    # X: random values, with a single nan in the first row of the first variable
    adata = AnnData(RNG.normal(size=(n_obs, 2)))
    adata.X[0, 0] = np.nan
    adata.var = pd.DataFrame({}, index=["col1", "col2"])

    # obs: random continuous columns, with nan for label 0 in `cola`
    adata.obs = pd.DataFrame(RNG.normal(size=(n_obs, 3)), columns=["cola", "colb", "colc"])
    adata.obs.loc[0, "cola"] = np.nan

    # consecutive integer ids, one per observation (assigned exactly once)
    adata.obs["instance_id"] = list(range(adata.n_obs))
    # categorical column containing one nan
    # NOTE(review): the five values are hard-coded - confirm they match the number of polygons
    adata.obs["category"] = pd.Series(["a", "b", np.nan, "c", "a"], dtype="category")
    adata.obs["region"] = "blobs_polygons"

    table = TableModel.parse(adata=adata, region_key="region", instance_key="instance_id", region="blobs_polygons")
    blob["table"] = table
    return blob


@pytest.fixture
def sdata_blobs_shapes_annotated() -> SpatialData:
"""Get blobs sdata with continuous annotation of polygons."""
Expand Down
12 changes: 12 additions & 0 deletions tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,15 @@ def test_plot_can_color_with_norm_no_clipping(self, sdata_blobs: SpatialData):
def test_plot_can_annotate_labels_with_table_layer(self, sdata_blobs: SpatialData):
sdata_blobs["table"].layers["normalized"] = RNG.random(sdata_blobs["table"].X.shape)
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum", table_layer="normalized").pl.show()

def test_plot_can_annotate_labels_with_nan_in_table_obs_categorical(self, sdata_blobs: SpatialData):
    """Render labels colored by a categorical ``obs`` column whose last entry is nan."""
    # Look the table up by key (as the sibling tests do) rather than through
    # the `.table` attribute accessor.
    # NOTE(review): `.table` is presumably deprecated upstream - confirm against spatialdata docs.
    sdata_blobs["table"].obs["cat_color"] = pd.Categorical(["a", "b", "b", "a", "b"] * 5 + [np.nan])
    sdata_blobs.pl.render_labels("blobs_labels", color="cat_color").pl.show()

def test_plot_can_annotate_labels_with_nan_in_table_obs_continuous(self, sdata_blobs: SpatialData):
    """Render labels colored by a continuous ``obs`` column that contains nan."""
    # one nan followed by the integers 2..13, repeated twice
    cont_values = [np.nan, *range(2, 14)] * 2
    sdata_blobs["table"].obs["cont_color"] = cont_values
    rendered = sdata_blobs.pl.render_labels("blobs_labels", color="cont_color")
    rendered.pl.show()

def test_plot_can_annotate_labels_with_nan_in_table_X_continuous(self, sdata_blobs: SpatialData):
    """Render labels colored by ``channel_0_sum`` after writing nan into the table's ``X``."""
    # overwrite the first five rows of the first column of X with nan
    expression = sdata_blobs["table"].X
    expression[:5, 0] = np.nan
    rendered = sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum")
    rendered.pl.show()
Loading
Loading