# Source code for smx.plotting.zones

"""
Plot zone-level ranking overlays on top of a reference spectrum.

The main entry point, :func:`plot_zone_ranking_over_spectrum`, accepts either:

* a precomputed ranking DataFrame with ``zone`` / ``score`` / ``rank`` columns
* an SMX LRC table with ``Zone`` / ``Local_Reaching_Centrality`` columns

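and renders a Plotly figure where each spectral zone is highlighted as a
ranked band over the reference spectrum. The figure can optionally be saved
as interactive HTML or, with ``kaleido`` installed, as a static image.
A simpler helper, :func:`plot_spectrum_with_zones`, draws unranked zones over
a single spectrum.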
"""

from __future__ import annotations

from pathlib import Path
from typing import Dict, Iterable, Optional, Union

import numpy as np
import pandas as pd

from smx.plotting.theme import DEFAULT_THEME, SMXTheme, blend_with_white, build_blended_colorscale


def _prepare_zone_ranking_df(zone_ranking_df: pd.DataFrame) -> pd.DataFrame:
    """Normalize supported ranking-table shapes into zone / score / rank form."""
    if zone_ranking_df is None or zone_ranking_df.empty:
        raise ValueError("zone_ranking_df must be a non-empty DataFrame.")

    ranking_df = zone_ranking_df.copy()

    if {"zone", "score"}.issubset(ranking_df.columns):
        normalized = ranking_df.rename(columns={"zone": "zone", "score": "score"}).copy()
    elif {"Zone", "Local_Reaching_Centrality"}.issubset(ranking_df.columns):
        # LRC tables may contain multiple predicates per zone. Collapse them to
        # a single score per zone using the strongest centrality observed.
        normalized = (
            ranking_df.groupby("Zone", as_index=False)["Local_Reaching_Centrality"]
            .max()
            .rename(columns={"Zone": "zone", "Local_Reaching_Centrality": "score"})
        )
    else:
        raise ValueError(
            "zone_ranking_df must contain either "
            "('zone', 'score') or ('Zone', 'Local_Reaching_Centrality') columns."
        )

    normalized["zone"] = normalized["zone"].astype(str)
    normalized["score"] = pd.to_numeric(normalized["score"], errors="coerce")
    normalized = normalized.dropna(subset=["score"])
    normalized = normalized.sort_values("score", ascending=False).reset_index(drop=True)

    if "rank" in ranking_df.columns and {"zone", "score"}.issubset(ranking_df.columns):
        normalized["rank"] = pd.to_numeric(ranking_df.loc[normalized.index, "rank"], errors="coerce")
        if normalized["rank"].isna().any():
            normalized["rank"] = np.arange(1, len(normalized) + 1)
    else:
        normalized["rank"] = np.arange(1, len(normalized) + 1)

    return normalized[["zone", "score", "rank"]]


def _aggregate_spectrum_df(
    spectrum_df: pd.DataFrame,
    aggregation: str,
) -> pd.Series:
    if spectrum_df.empty:
        raise ValueError("Reference spectrum DataFrame is empty.")

    if aggregation == "mean":
        spectrum = spectrum_df.mean(axis=0)
    elif aggregation == "median":
        spectrum = spectrum_df.median(axis=0)
    else:
        raise ValueError("aggregation must be 'mean' or 'median'.")

    spectrum.index = pd.to_numeric(spectrum.index.astype(str), errors="coerce")
    spectrum = spectrum[~spectrum.index.isna()]
    return spectrum.sort_index()


def _build_reference_spectrum(
    reference_spectrum: Union[pd.Series, pd.DataFrame, Dict[str, pd.DataFrame]],
    spectral_cuts: Iterable,
    aggregation: str,
) -> pd.Series:
    if isinstance(reference_spectrum, pd.Series):
        spectrum = reference_spectrum.copy()
        spectrum.index = pd.to_numeric(spectrum.index.astype(str), errors="coerce")
        spectrum = spectrum[~spectrum.index.isna()]
        return spectrum.sort_index()

    if isinstance(reference_spectrum, pd.DataFrame):
        return _aggregate_spectrum_df(reference_spectrum, aggregation=aggregation)

    if isinstance(reference_spectrum, dict):
        series_parts = []
        seen_x = set()
        for cut in spectral_cuts:
            if isinstance(cut, dict):
                zone_name = str(cut["name"])
            else:
                zone_name = str(cut[0]) if len(cut) == 3 else f"{cut[0]}-{cut[1]}"

            zone_df = reference_spectrum.get(zone_name)
            if zone_df is None or zone_df.empty:
                continue
            zone_series = _aggregate_spectrum_df(zone_df, aggregation=aggregation)
            zone_series = zone_series[~zone_series.index.isin(seen_x)]
            seen_x.update(zone_series.index.tolist())
            series_parts.append(zone_series)

        if not series_parts:
            raise ValueError("Could not build a reference spectrum from the provided zone dictionary.")

        return pd.concat(series_parts).sort_index()

    raise TypeError(
        "reference_spectrum must be a pandas Series, pandas DataFrame, "
        "or dict[str, pandas.DataFrame]."
    )


def plot_spectrum_with_zones(
    spectrum: Union[pd.Series, pd.DataFrame, np.ndarray],
    spectral_cuts: Iterable,
    identified_peaks: Optional[Iterable[int]] = None,
    identified_minima: Optional[Iterable[int]] = None,
    output_path: Optional[Union[str, Path]] = None,
    *,
    title: Optional[str] = None,
    zone_color: str = "rgb(173, 216, 230)",
    background_color: str = "rgb(0, 34, 75)",
    width: Optional[int] = 1200,
    height: Optional[int] = 500,
    theme: Optional[SMXTheme] = None,
    return_df: bool = False,
) -> Union[None, pd.DataFrame]:
    """Plot a spectrum with spectral zones highlighted in the background.

    The figure is always displayed; set *return_df=True* to return the
    normalized cuts DataFrame.

    Parameters
    ----------
    spectrum : pandas.Series, pandas.DataFrame, or numpy.ndarray
        Spectrum values. If a DataFrame is provided, the first row is used.
        If any index label cannot be parsed as a number, points are plotted
        at positions ``0 .. n-1`` instead.
    spectral_cuts : iterable
        Zone definitions as ``(label, start, end)`` or ``(start, end)``
        tuples, or dicts with ``name`` / ``start`` / ``end`` keys. Boundaries
        may be numbers or labels from *spectrum*'s index; unknown labels map
        to ``0.0``.
    identified_peaks : iterable of int, optional
        Positional indices of local maxima to mark on the plot. Out-of-range
        indices are ignored.
    identified_minima : iterable of int, optional
        Positional indices of local minima to mark on the plot. Out-of-range
        indices are ignored.
    output_path : str or Path, optional
        Destination file (``.html`` or ``.png``/``.svg``/``.pdf``/``.jpg``/
        ``.jpeg``/``.webp``). If omitted, no file is written.
    title : str, optional
        Plot title. Defaults to ``"Spectrum with spectral zones"``.
    zone_color : str, default "rgb(173, 216, 230)"
        Fill color for spectral zones, given as ``rgb(...)``, ``rgba(...)``
        (its alpha is ignored) or ``#rrggbb`` / ``#rgb``.
    background_color : str, default "rgb(0, 34, 75)"
        Fill color for zones whose label contains ``"background"``. Accepts
        the same formats as *zone_color*.
    width : int, default 1200
        Figure width in pixels for static image export.
    height : int, default 500
        Figure height in pixels for static image export.
    theme : SMXTheme, optional
        Visual theme controlling the layout, fonts and line styles. Defaults
        to ``DEFAULT_THEME``.
    return_df : bool, default False
        If ``True``, return the normalized cuts DataFrame.

    Returns
    -------
    pd.DataFrame or None
        Normalized cuts DataFrame (``label`` / ``start`` / ``end``, sorted by
        ``start``) when *return_df* is True.

    Raises
    ------
    ImportError
        If plotly is missing, or kaleido is missing for static export.
    ValueError
        If the spectrum is empty, a cut is malformed, or the output format
        is unsupported.
    """
    try:
        import plotly.graph_objects as go
    except ImportError as exc:
        raise ImportError(
            "plotly is required for plot_spectrum_with_zones. "
            "Install it with: pip install plotly"
        ) from exc

    theme = theme or DEFAULT_THEME

    if isinstance(spectrum, pd.DataFrame):
        if spectrum.empty:
            raise ValueError("spectrum DataFrame is empty.")
        spectrum_series = spectrum.iloc[0, :]
    elif isinstance(spectrum, pd.Series):
        spectrum_series = spectrum
    else:
        spectrum_series = pd.Series(np.asarray(spectrum, dtype=float))

    if spectrum_series.empty:
        raise ValueError("spectrum is empty.")

    raw_index = spectrum_series.index.astype(str)
    x_numeric = pd.to_numeric(raw_index, errors="coerce")
    if x_numeric.isna().any():
        # Non-numeric axis: plot against positions, and let cut boundaries
        # given as index labels resolve to those positions.
        x_values = np.arange(len(spectrum_series), dtype=float)
        label_map = {label: float(pos) for pos, label in enumerate(raw_index)}
    else:
        x_values = x_numeric.to_numpy(dtype=float)
        label_map = {label: float(val) for label, val in zip(raw_index, x_values)}
    y_values = spectrum_series.to_numpy(dtype=float)

    def _to_x(value: Union[int, float, str]) -> float:
        """Convert a cut boundary (number or index label) to an x position."""
        try:
            return float(value)
        except (TypeError, ValueError):
            # Unknown labels fall back to 0.0.
            return float(label_map.get(str(value), 0.0))

    def _rgba(color: str, opacity: float) -> str:
        """Convert an ``rgb(...)``/``rgba(...)``/hex color to rgba with *opacity*."""
        text = color.strip()
        if text.startswith("#"):
            digits = text[1:]
            if len(digits) == 3:
                digits = "".join(ch * 2 for ch in digits)
            channels = [int(digits[i:i + 2], 16) for i in (0, 2, 4)]
        else:
            inner = text[text.index("(") + 1:text.rindex(")")]
            channels = [int(float(v)) for v in inner.split(",")[:3]]
        return f"rgba({channels[0]}, {channels[1]}, {channels[2]}, {opacity})"

    cut_rows = []
    for cut in spectral_cuts:
        if isinstance(cut, dict):
            label = str(cut.get("name", "zone"))
            start = cut.get("start")
            end = cut.get("end")
        elif len(cut) == 3:
            label, start, end = cut
            label = str(label)
        elif len(cut) == 2:
            start, end = cut
            label = f"{start}-{end}"
        else:
            raise ValueError("Each cut must have 2 or 3 elements, or dict form.")
        x_start = _to_x(start)
        x_end = _to_x(end)
        if x_start > x_end:
            x_start, x_end = x_end, x_start
        cut_rows.append({"label": label, "start": x_start, "end": x_end})

    # Explicit columns keep sort_values working when no cuts are given.
    cuts_df = (
        pd.DataFrame(cut_rows, columns=["label", "start", "end"])
        .sort_values("start")
        .reset_index(drop=True)
    )

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=y_values,
            mode="lines",
            line=dict(
                color=theme.reference_line_color,
                width=theme.reference_line_width,
                dash="solid",
            ),
            name="Spectrum",
        )
    )

    # Fixed 0.6 opacity for zone/background shaded regions (independent of
    # theme). The alpha lives only in the rgba fill, with the shape opacity
    # left at 1, so the bands render in the same color as the legend swatches.
    zone_opacity = 0.6
    zone_rgba = _rgba(zone_color, zone_opacity)
    background_rgba = _rgba(background_color, zone_opacity)
    # Boundary lines are black; width and dash come from the theme.
    boundary_line = dict(
        color="black",
        width=theme.zone_boundary_width,
        dash=theme.zone_boundary_dash,
    )

    for row in cuts_df.itertuples(index=False):
        is_background = "background" in str(row.label).lower()
        fig.add_vrect(
            x0=float(row.start),
            x1=float(row.end),
            fillcolor=background_rgba if is_background else zone_rgba,
            line_width=0,
            layer="below",
        )
        fig.add_vline(x=float(row.start), line=boundary_line)

    # Close the last zone; independent of whether peaks were supplied.
    if not cuts_df.empty:
        fig.add_vline(x=float(cuts_df["end"].max()), line=boundary_line)

    # Legend entries for zone and background shading.
    for swatch_color, swatch_name in ((zone_rgba, "zones"), (background_rgba, "backgrounds")):
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                marker=dict(size=14, color=swatch_color, symbol="square"),
                name=swatch_name,
                showlegend=True,
            )
        )

    def _add_markers(indices, name: str, color: str, symbol: str) -> None:
        """Mark positional *indices* on the spectrum, skipping out-of-range ones."""
        if indices is None:
            return
        idx = np.asarray(list(indices), dtype=int)
        idx = idx[(idx >= 0) & (idx < len(y_values))]
        if idx.size > 0:
            fig.add_trace(
                go.Scatter(
                    x=x_values[idx],
                    y=y_values[idx],
                    mode="markers",
                    marker=dict(size=8, color=color, symbol=symbol),
                    name=name,
                )
            )

    _add_markers(identified_peaks, "identified peaks", "#e41a1c", "circle")
    _add_markers(identified_minima, "identified minima", "#4daf4a", "triangle-up")

    y_min = float(np.nanmin(y_values))
    y_max = float(np.nanmax(y_values))
    y_span = y_max - y_min if y_max > y_min else 1.0

    fig.update_layout(
        **theme.plotly_layout(
            title=title or "Spectrum with spectral zones",
            xaxis_title="Spectral variables",
            yaxis_title="Intensity",
            margin=dict(t=80, r=60, b=90, l=60),
            legend=dict(
                orientation="h",
                yanchor="top",
                y=1.12,
                xanchor="center",
                x=0.5,
            ),
        )
    )
    fig.update_yaxes(range=[y_min - 0.05 * y_span, y_max + 0.12 * y_span])

    if output_path is not None:
        output_path = Path(output_path)
        suffix = output_path.suffix.lower()
        if suffix == ".html":
            fig.write_html(str(output_path))
        elif suffix in {".png", ".svg", ".pdf", ".jpg", ".jpeg", ".webp"}:
            try:
                fig.write_image(str(output_path), width=width, height=height)
            except ValueError as exc:
                raise ImportError(
                    "Static image export requires kaleido. "
                    "Install it with: pip install kaleido"
                ) from exc
        else:
            raise ValueError(
                f"Unsupported output format '{suffix}'. "
                "Use '.html' for interactive or '.png'/'.svg'/'.pdf' for static image."
            )

    fig.show()
    if return_df:
        return cuts_df
    return None


_DEFAULT_CLASS_COLORS = [ "#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999", ]
def plot_zone_ranking_over_spectrum(
    zone_ranking_df: pd.DataFrame,
    spectral_cuts: Iterable,
    reference_spectrum: Union[pd.Series, pd.DataFrame, Dict[str, pd.DataFrame]],
    output_path: Optional[Union[str, Path]],
    *,
    aggregation: str = "mean",
    title: Optional[str] = None,
    spectrum_name: str = "Reference spectrum",
    colorscale: str = "YlOrRd",
    annotation_y: float = 1.06,
    class_spectra: Optional[Dict[str, Union[pd.Series, pd.DataFrame, Dict[str, pd.DataFrame]]]] = None,
    class_colors: Optional[Dict[str, str]] = None,
    width: Optional[int] = 1200,
    height: Optional[int] = 500,
    theme: Optional[SMXTheme] = None,
    return_df: bool = False,
) -> Union[None, pd.DataFrame]:
    """Plot ranked zones as colored bands over a reference spectrum.

    The figure is always displayed; set *return_df=True* to return the
    normalized ranking DataFrame. When *output_path* is given, the output
    format is inferred from its suffix:

    * ``.html`` — interactive Plotly figure (no extra dependency)
    * ``.png``, ``.svg``, ``.pdf``, ``.jpg``, ``.jpeg``, ``.webp`` — static
      image via ``kaleido`` (install with ``pip install kaleido``)

    Parameters
    ----------
    zone_ranking_df : pd.DataFrame
        Either a ranking table with ``zone`` / ``score`` (and optionally
        ``rank``) columns or an SMX LRC table with ``Zone`` /
        ``Local_Reaching_Centrality``.
    spectral_cuts : iterable
        Zone definitions as ``(label, start, end)`` or ``(start, end)``
        tuples, or dicts with ``name`` / ``start`` / ``end`` keys. Any
        iterable is accepted, including one-shot iterators.
    reference_spectrum : pd.Series, pd.DataFrame, or dict[str, pd.DataFrame]
        Spectrum used as the background line. If a DataFrame is provided,
        rows are aggregated with ``aggregation``. If a zone dictionary is
        provided, each zone is aggregated and stitched back together
        following ``spectral_cuts`` order.
    output_path : str or Path or None
        Destination file. If ``None``, no file is written.
    aggregation : {'mean', 'median'}, default 'mean'
        Aggregation used when *reference_spectrum* is a DataFrame or zone dict.
    title : str, optional
        Figure title. Defaults to ``"Zone ranking over spectrum"``.
    spectrum_name : str, default 'Reference spectrum'
        Legend label for the background spectrum.
    colorscale : str, default 'YlOrRd'
        Plotly colorscale name used for zone bands. Because ``'YlOrRd'`` is
        the default, passing it explicitly is indistinguishable from not
        passing it, and the theme's colorscale is used in that case.
    annotation_y : float, default 1.06
        Annotation y-position in paper coordinates.
    class_spectra : dict[str, Series | DataFrame | dict[str, DataFrame]], optional
        Per-class spectra to overlay. Keys are class labels; values accept
        the same forms as *reference_spectrum*. Each class is plotted as a
        separate colored line, using ``aggregation`` to collapse rows.
    class_colors : dict[str, str], optional
        Color strings keyed by class label. Labels without an entry get a
        color from ``theme.resolve_class_color``.
    width : int, default 1200
        Figure width in pixels. Used only for static image exports.
    height : int, default 500
        Figure height in pixels. Used only for static image exports.
    theme : SMXTheme, optional
        Visual theme controlling colors, fonts, line styles, and the Plotly
        layout. Defaults to ``DEFAULT_THEME``. A non-default *colorscale*
        and entries in *class_colors* take precedence over the theme.
    return_df : bool, default False
        If ``True``, return the normalized ranking DataFrame.

    Returns
    -------
    pd.DataFrame or None
        Normalized ranking DataFrame (``zone`` / ``score`` / ``rank``) when
        *return_df* is True, otherwise ``None``.

    Raises
    ------
    ImportError
        If plotly is missing, or kaleido is missing for static export.
    ValueError
        For an unusable ranking table or reference spectrum, a malformed
        cut, or an unsupported output format.
    TypeError
        If a reference or class spectrum has an unsupported type.

    Notes
    -----
    Zones without a ranking value are drawn in light gray. A vertical
    colorbar on the right shows the score-to-color mapping. Its palette is
    pre-blended with the plot background so it matches the zone band
    colors. Tick marks are placed at ``score_min`` and ``score_max`` and
    labeled accordingly.
    """
    try:
        import plotly.graph_objects as go
        from plotly.colors import sample_colorscale
    except ImportError as exc:
        raise ImportError(
            "plotly is required for plot_zone_ranking_over_spectrum. "
            "Install it with: pip install plotly"
        ) from exc

    theme = theme or DEFAULT_THEME
    # The signature default "YlOrRd" is treated as "not set": fall back to
    # the theme's colorscale. Any other explicit value wins.
    _colorscale = colorscale if colorscale != "YlOrRd" else theme.colorscale

    ranking_df = _prepare_zone_ranking_df(zone_ranking_df)

    # spectral_cuts is iterated several times below: the reference spectrum,
    # each class spectrum, and the band layout. Materialize it so one-shot
    # iterables such as generators are not exhausted after the first pass.
    spectral_cuts = list(spectral_cuts)

    spectrum = _build_reference_spectrum(reference_spectrum, spectral_cuts, aggregation=aggregation)
    spectrum = spectrum.dropna()
    if spectrum.empty:
        raise ValueError("Reference spectrum is empty after preprocessing.")

    # Build per-class aggregated spectra when provided.
    class_series: Dict[str, pd.Series] = {}
    resolved_colors: Dict[str, str] = {}
    if class_spectra:
        used_palette: list = []
        for label, src in class_spectra.items():
            cs = _build_reference_spectrum(src, spectral_cuts, aggregation=aggregation).dropna()
            if not cs.empty:
                class_series[label] = cs
                resolved_colors[label] = (
                    (class_colors or {}).get(label)
                    or theme.resolve_class_color(label, used_palette)
                )

    score_min = float(ranking_df["score"].min())
    score_max = float(ranking_df["score"].max())
    vrect_opacity = theme.zone_opacity

    def _score_to_color(score: float) -> str:
        """Sample the colorscale at *score*; all-equal scores map to its top."""
        if score_max == score_min:
            norm = 1.0
        else:
            norm = (float(score) - score_min) / (score_max - score_min)
        return sample_colorscale(_colorscale, [norm])[0]

    # Colorscale whose colors match the blended zone backgrounds.
    blended_colorscale = build_blended_colorscale(_colorscale, vrect_opacity)

    cut_rows = []
    for cut in spectral_cuts:
        if isinstance(cut, dict):
            zone_name = str(cut["name"])
            start = float(cut["start"])
            end = float(cut["end"])
        elif len(cut) == 3:
            zone_name, start, end = cut
            zone_name = str(zone_name)
            start = float(start)
            end = float(end)
        elif len(cut) == 2:
            start, end = cut
            zone_name = f"{start}-{end}"
            start = float(start)
            end = float(end)
        else:
            raise ValueError("Each spectral cut must have 2 or 3 elements, or dict form.")
        if start > end:
            start, end = end, start
        cut_rows.append({"zone": zone_name, "start": start, "end": end})

    # Explicit columns keep the merge working when no cuts are given.
    cut_df = pd.DataFrame(cut_rows, columns=["zone", "start", "end"])
    plot_df = (
        cut_df.merge(ranking_df, on="zone", how="left")
        .sort_values("start")
        .reset_index(drop=True)
    )

    # Y-axis bounds across all spectra that will be drawn.
    all_values = np.concatenate(
        [spectrum.to_numpy(dtype=float)]
        + [cs.to_numpy(dtype=float) for cs in class_series.values()]
    )
    ymax = float(np.nanmax(all_values))
    ymin = float(np.nanmin(all_values))
    yspan = ymax - ymin if ymax > ymin else 1.0

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=spectrum.index.to_numpy(dtype=float),
            y=spectrum.to_numpy(dtype=float),
            mode="lines",
            line=dict(
                color=theme.reference_line_color,
                width=theme.reference_line_width,
                dash=theme.reference_line_dash,
            ),
            name=spectrum_name,
        )
    )

    for label, cs in class_series.items():
        fig.add_trace(
            go.Scatter(
                x=cs.index.to_numpy(dtype=float),
                y=cs.to_numpy(dtype=float),
                mode="lines",
                line=dict(color=resolved_colors[label], width=theme.class_line_width),
                name=f"Class {label}",
            )
        )

    boundary_line = dict(
        color=theme.zone_boundary_color,
        width=theme.zone_boundary_width,
        dash=theme.zone_boundary_dash,
    )

    for _, row in plot_df.iterrows():
        start = float(row["start"])
        end = float(row["end"])
        zone_name = row["zone"]
        score = row.get("score")
        rank = row.get("rank")
        # Zones absent from the ranking are drawn in a faint gray.
        color = "rgba(180,180,180,0.15)" if pd.isna(score) else _score_to_color(float(score))

        fig.add_vrect(
            x0=start,
            x1=end,
            fillcolor=color,
            opacity=vrect_opacity,
            line_width=0,
            layer="below",
        )
        fig.add_vline(x=start, line=boundary_line)

        midpoint = (start + end) / 2.0
        rank_line = f"#{int(rank)}" if pd.notna(rank) else ""
        score_line = f"{float(score):.3f}" if pd.notna(score) else ""
        annotation_text = "<br>".join(part for part in [rank_line, zone_name, score_line] if part)
        fig.add_annotation(
            x=midpoint,
            y=annotation_y,
            xref="x",
            yref="paper",
            text=annotation_text,
            showarrow=False,
            align="center",
            font=dict(size=theme.annotation_font_size, family=theme.font_family),
        )

        if pd.notna(score):
            hover = (
                f"Zone: {zone_name}<br>"
                f"Range: {start:.3f} - {end:.3f}<br>"
                f"Rank: {int(rank) if pd.notna(rank) else '-'}<br>"
                f"Score: {float(score):.3f}"
            )
        else:
            hover = f"Zone: {zone_name}<br>Range: {start:.3f} - {end:.3f}<br>No ranking value"

        # Invisible marker at the zone midpoint that carries the hover text.
        fig.add_trace(
            go.Scatter(
                x=[midpoint],
                y=[ymax + 0.04 * yspan],
                mode="markers",
                marker=dict(size=10, color=color, opacity=0),
                name=zone_name,
                showlegend=False,
                hovertemplate=hover,
            )
        )

    # Close the right-most zone (max end, robust to overlapping cuts).
    if not plot_df.empty:
        fig.add_vline(x=float(plot_df["end"].max()), line=boundary_line)

    # Invisible scatter whose sole purpose is to render the score colorbar.
    fig.add_trace(
        go.Scatter(
            x=[None],
            y=[None],
            mode="markers",
            marker=dict(
                colorscale=blended_colorscale,
                cmin=score_min,
                cmax=score_max,
                color=[score_min],
                size=0,
                opacity=0,
                showscale=True,
                colorbar=dict(
                    title=dict(text="LRC score", side="right"),
                    thickness=theme.colorbar_thickness,
                    len=theme.colorbar_len,
                    x=1.02,
                    xanchor="left",
                    y=0.5,
                    yanchor="middle",
                    tickmode="array",
                    tickvals=[score_min, score_max],
                    ticktext=[f"{score_min:.3f}<br>(min)", f"{score_max:.3f}<br>(max)"],
                    tickfont=dict(size=10),
                ),
            ),
            hoverinfo="skip",
            showlegend=False,
        )
    )

    fig.update_layout(
        **theme.plotly_layout(
            title=title or "Zone ranking over spectrum",
            xaxis_title="Spectral variables",
            yaxis_title="Intensity",
            margin=dict(t=110, r=100, b=90, l=60),
            legend=dict(
                orientation="h",
                yanchor="top",
                y=-0.24,
                xanchor="center",
                x=0.5,
            ),
        )
    )
    fig.update_yaxes(range=[ymin - 0.05 * yspan, ymax + 0.12 * yspan])

    if output_path is not None:
        output_path = Path(output_path)
        suffix = output_path.suffix.lower()
        if suffix == ".html":
            fig.write_html(str(output_path))
        elif suffix in {".png", ".svg", ".pdf", ".jpg", ".jpeg", ".webp"}:
            try:
                fig.write_image(str(output_path), width=width, height=height)
            except ValueError as exc:
                raise ImportError(
                    "Static image export requires kaleido. "
                    "Install it with: pip install kaleido"
                ) from exc
        else:
            raise ValueError(
                f"Unsupported output format '{suffix}'. "
                "Use '.html' for interactive or '.png'/'.svg'/'.pdf' for static image."
            )

    fig.show()
    if return_df:
        return ranking_df
    return None