Source code for smx.plotting.threshold

"""
plot_threshold_spectrum: visualise a multivariate threshold overlaid on
the original spectral zone, coloured by class.

Requires ``plotly``.  The dependency is optional — import errors produce a
clear, actionable message rather than a hard package-load failure.
"""

from __future__ import annotations

from pathlib import Path
from typing import Dict, Optional, Union

import numpy as np
import pandas as pd

from smx.graph.interpretation import reconstruct_threshold_to_spectrum
from smx.plotting.theme import DEFAULT_THEME, SMXTheme


def plot_threshold_spectrum(
    lrc_natural_df: pd.DataFrame,
    row_index: int,
    spectral_zones_original: Dict[str, pd.DataFrame],
    pca_info_dict_original: Dict,
    y_labels: pd.Series,
    output_path: Optional[Union[str, Path]],
    class_colors: Optional[Dict[str, str]] = None,
    theme: Optional[SMXTheme] = None,
    width: Optional[int] = 900,
    height: Optional[int] = 450,
    return_df: bool = False,
) -> Union[None, pd.Series]:
    """Reconstruct a threshold to spectrum space, plot it and optionally save it.

    The plot overlays the reconstructed multivariate threshold (drawn with the
    theme's threshold colour, width and dash style) on top of the individual
    sample spectra of the chosen spectral zone, coloured by class label.

    The figure is always displayed; set *return_df=True* to also get the
    threshold spectrum back.

    Parameters
    ----------
    lrc_natural_df : pd.DataFrame
        LRC DataFrame with natural-scale thresholds. Must contain columns
        ``'Zone'`` and ``'Threshold_Natural'``; ``'Node_Natural'`` is used for
        the plot title when present.
    row_index : int
        Positional row of *lrc_natural_df* to visualise.
    spectral_zones_original : dict[str, pd.DataFrame]
        Spectral zones extracted from the *unpreprocessed* calibration data,
        keyed by zone name.
    pca_info_dict_original : dict
        PCA info from :class:`smx.zones.aggregation.ZoneAggregator` fitted on
        the natural (unpreprocessed) data.
    y_labels : pd.Series
        Class labels aligned with the calibration data rows. When it has one
        entry per row of the zone DataFrame the labels are matched by
        position; otherwise an integer row label is used as a position into
        *y_labels*, and rows that cannot be matched are labelled
        ``"Unknown"``.
    output_path : str or Path, optional
        Destination file. ``.html`` writes an interactive page; ``.png``,
        ``.svg``, ``.pdf``, ``.jpg``, ``.jpeg`` and ``.webp`` write a static
        image. If ``None``, no file is written.
    class_colors : dict, optional
        Mapping of class label → colour string. Explicit values override the
        theme. Defaults to the theme's ``class_colors``. Labels are looked up
        as-is first and then by their string form.
    theme : SMXTheme, optional
        Visual theme. Defaults to :data:`smx.plotting.theme.DEFAULT_THEME`.
    width : int, optional
        Width passed to the static image export. Not used for ``.html``.
    height : int, optional
        Height passed to the static image export. Not used for ``.html``.
    return_df : bool, default False
        If ``True``, return the threshold spectrum Series.

    Returns
    -------
    pd.Series or None
        The reconstructed threshold spectrum when *return_df* is ``True``,
        otherwise ``None``.

    Raises
    ------
    ImportError
        If ``plotly`` is not installed, or if the static image export raises
        ``ValueError`` (reported as a missing ``kaleido`` installation).
    ValueError
        If *output_path* has an unsupported file extension.
    """
    try:
        import plotly.graph_objects as go
    except ImportError as exc:
        raise ImportError(
            "plotly is required for plot_threshold_spectrum. \n"
            "Install it with: pip install plotly"
        ) from exc

    theme = theme or DEFAULT_THEME
    class_colors = class_colors or theme.class_colors

    # Fetch the selected LRC row once instead of re-indexing for every field.
    lrc_row = lrc_natural_df.iloc[row_index]
    zone_name = lrc_row["Zone"]
    threshold_score = float(lrc_row["Threshold_Natural"])

    threshold_spectrum = reconstruct_threshold_to_spectrum(
        threshold_value=threshold_score,
        zone_name=zone_name,
        pca_info_dict=pca_info_dict_original,
    )

    zone_df = spectral_zones_original[zone_name]
    x_values = pd.to_numeric(zone_df.columns, errors="coerce")

    n_labels = len(y_labels)
    # One label per spectrum: match by row position, which is independent of
    # whatever index labels the zone DataFrame happens to carry.
    rows_aligned = n_labels == len(zone_df)
    fallback_color = "rgba(128,128,128,0.3)"

    fig = go.Figure()
    seen_classes: set = set()

    for position, (idx, row) in enumerate(zone_df.iterrows()):
        if rows_aligned:
            class_label = y_labels.iloc[position]
        elif isinstance(idx, (int, np.integer)) and 0 <= idx < n_labels:
            # Lengths differ: fall back to treating an integer row label as a
            # position into y_labels. Non-integer or out-of-range labels must
            # not reach .iloc (TypeError / silent negative wrap-around).
            class_label = y_labels.iloc[idx]
        else:
            class_label = "Unknown"

        # Only the first trace of each class gets a legend entry.
        show_legend = class_label not in seen_classes
        seen_classes.add(class_label)

        # class_colors is keyed by str in the signature, but labels may be
        # numeric; try the raw label first, then its string form.
        color = class_colors.get(class_label)
        if color is None:
            color = class_colors.get(str(class_label), fallback_color)

        fig.add_trace(
            go.Scatter(
                x=x_values,
                y=row.values,
                mode="lines",
                line=dict(color=color, width=0.5),
                name=f"Class {class_label}",
                # legendgroup is a string property in plotly; NumPy scalar
                # labels are stringified explicitly.
                legendgroup=str(class_label),
                showlegend=show_legend,
                hoverinfo="skip",
            )
        )

    fig.add_trace(
        go.Scatter(
            x=x_values,
            y=threshold_spectrum.values,
            mode="lines",
            line=dict(
                color=theme.threshold_color,
                width=theme.threshold_line_width,
                dash=theme.threshold_line_dash,
            ),
            name=f"Threshold Spectrum ({threshold_spectrum.name})",
        )
    )

    node_natural = lrc_row.get("Node_Natural", "")
    fig.update_layout(
        **theme.plotly_layout(
            title=f"Zone '{zone_name}' — Multivariate Threshold (Predicate: {node_natural})",
            xaxis_title="Spectral variables",
            yaxis_title="Intensity",
            showlegend=True,
            legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
        )
    )

    if output_path is not None:
        output_path = Path(output_path)
        suffix = output_path.suffix.lower()
        if suffix == ".html":
            fig.write_html(str(output_path))
        elif suffix in {".png", ".svg", ".pdf", ".jpg", ".jpeg", ".webp"}:
            try:
                fig.write_image(str(output_path), width=width, height=height)
            except ValueError as exc:
                # The ValueError from the static export is surfaced as a
                # missing-dependency error with an install hint.
                raise ImportError(
                    "Static image export requires kaleido. \n"
                    "Install it with: pip install kaleido"
                ) from exc
        else:
            raise ValueError(
                f"Unsupported output format '{suffix}'. "
                "Use '.html' for interactive or '.png'/'.svg'/'.pdf' for static image."
            )

    fig.show()

    return threshold_spectrum if return_df else None