# Source code for smx.pipeline

"""
SMX: high-level facade for the full SMX explanation pipeline.

This class internalises the seed-loop orchestration that every caller would
otherwise have to rewrite manually (zone extraction → predicate generation →
bagging → metric → graph → LRC → natural-scale mapping across N seeds).

Individual component classes (``ZoneAggregator``, ``PredicateGenerator``, etc.)
remain available for power users who need fine-grained control.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Sequence, Union

import networkx as nx
import numpy as np
import pandas as pd

from smx.zones.extraction import extract_spectral_zones
from smx.zones.aggregation import ZoneAggregator
from smx.predicates.generation import PredicateGenerator
from smx.predicates.bagging import PredicateBagger
from smx.predicates.metrics import CovarianceMetric, PerturbationMetric
from smx.graph.builder import PredicateGraphBuilder
from smx.graph.centrality import compute_lrc, aggregate_lrc_across_seeds
from smx.graph.interpretation import map_thresholds_to_natural
from smx.evaluation.faithfulness import progressive_masking_faithfulness

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)

# Type alias used to annotate the ``spectral_cuts`` argument of ``SMX``:
# a list of zone-definition tuples (element types are not constrained here).
SpectralCuts = List[tuple]   # list of (name, start, end) or (name, start, end, group)


class SMX:
    r"""Full SMX explanation pipeline behind a single ``fit`` call.

    Runs zone extraction → PCA aggregation → predicate generation →
    seed-loop (bagging → metric → graph → LRC) → cross-seed aggregation →
    natural-scale threshold mapping.

    Parameters
    ----------
    spectral_cuts : list of (name, start, end) or (name, start, end, group) tuples
        Zone definitions, e.g. ``[("Low", 1.0, 4.0), ("High", 4.0, 10.0)]``.
    quantiles : list of float
        Quantile fractions for predicate generation, e.g. ``[0.25, 0.5, 0.75]``.
    n_repetitions : int, default 4
        Number of independent bagging repetitions. Seeds are generated as
        ``[0, 1, …, n_repetitions-1]``.
    n_bags : int, default 10
        Number of bags per seed.
    n_samples_fraction : float, default 0.8
        Fraction of calibration samples drawn per bag. The minimum samples per
        predicate is hardcoded to 20 % of the dataset.
    replace : bool, default False
        Whether to sample bags with replacement.
    metric : {'covariance', 'perturbation'}, default 'perturbation'
        Importance metric to use.
    estimator : sklearn-compatible estimator, optional
        Trained model required when ``metric='perturbation'``.
    perturbation_mode : str, default 'median'
        Replacement strategy for perturbation (``'constant'``, ``'mean'``,
        ``'median'``, ``'min'``, ``'max'``).
    perturbation_value : float, default 0
        Constant replacement value used when ``perturbation_mode='constant'``.
    perturbation_metric : str, default 'probability_shift'
        Perturbation importance measure. Determines how the impact of spectral
        zone perturbation is quantified. Choice depends on the estimator type
        and the desired sensitivity:

        **Classification estimators** (with ``predict_proba``):

        - ``'probability_shift'`` — Mean total variation distance between
          pre- and post-perturbation class probabilities. Sensitive to
          confidence changes across all classes. Requires ``predict_proba``.
        - ``'accuracy_drop'`` — Drop in accuracy when perturbed predictions
          are compared to original predictions.
        - ``'f1_drop'`` — Weighted F1-score decrease after perturbation.
        - ``'decision_function_shift'`` — Mean absolute change in decision
          function values (e.g. signed distances from hyperplane for SVC).
          Requires ``decision_function()``.

        **Regression estimators** (with ``predict`` returning continuous
        values):

        - ``'mean_abs_diff'`` — Mean absolute difference between original and
          perturbed predictions.
        - ``'mean_diff'`` — Mean signed difference (bias direction). Positive
          values indicate perturbation increases predictions, negative
          decreases.
        - ``'mean_relative_dev'`` — Mean relative deviation, normalized by
          original prediction magnitude. Treats zero predictions as NaN.
    perturbation_stats_source : str, default 'full'
        Forwarded unchanged to ``PerturbationMetric`` as its ``stats_source``
        argument (perturbation metric only). See that class for the accepted
        values.
    normalize_by_zone_size : bool, default True
        Divide raw perturbation importance by zone width.
    zone_size_exponent : float, default 1.0
        Exponent applied to zone size during normalisation.
    covariance_threshold : float, default 0.01
        Minimum covariance value to keep a predicate (covariance metric only).
    var_exp : bool, default True
        Weight graph edges by PC1 explained variance of the source zone.
    show_graph_details : bool, default False
        Print bidirectional-edge details during graph construction.
    class_threshold : float, default 0.5
        Decision boundary for ``Class_Predicted`` annotation on bags.

    Attributes (set after :meth:`fit`)
    ------------------------------------
    lrc_natural\_ : pd.DataFrame or None
        LRC with natural-scale thresholds (available only when
        ``X_cal_natural`` is provided to :meth:`fit`). Columns: ``Node``,
        ``Local_Reaching_Centrality``, ``Zone``, ``Threshold``, ``Operator``,
        ``Threshold_Natural``.
    lrc_summed\_ : pd.DataFrame
        Mean-aggregated LRC across seeds, preprocessed-scale thresholds.
    lrc_summed_unique\_ : pd.DataFrame
        Zone-deduplicated version of *lrc_summed_* (one row per zone).
    zone_scores\_ : pd.DataFrame
        PCA zone scores on the preprocessed calibration data.
    predicates_df\_ : pd.DataFrame
        Full predicate catalogue (generated from *zone_scores_*).
    pca_info\_ : dict
        PCA info for the preprocessed zones.
    pca_info_natural\_ : dict or None
        PCA info for the natural (unpreprocessed) zones (only when
        ``X_cal_natural`` is provided to :meth:`fit`).
    zones_natural\_ : dict or None
        Raw spectral zone DataFrames from the unpreprocessed data (only when
        ``X_cal_natural`` is provided to :meth:`fit`).
    graphs_by_seed\_ : dict[int, nx.DiGraph]
        Per-seed directed predicate graphs (useful for debugging).
    valid_seeds\_ : list[int]
        Seeds that produced a non-empty graph (subset of ``seeds``).
    faithfulness\_ : dict or None
        Progressive top-k masking evaluation summary produced by
        :meth:`evaluate_faithfulness`. Reset to ``None`` by :meth:`fit`.
    """

    def __init__(
        self,
        spectral_cuts: SpectralCuts,
        quantiles: List[float],
        n_repetitions: int = 4,
        n_bags: int = 10,
        n_samples_fraction: float = 0.8,
        replace: bool = False,
        metric: Literal["covariance", "perturbation"] = "perturbation",
        estimator: Optional[Any] = None,
        perturbation_mode: str = "median",
        perturbation_value: float = 0,
        perturbation_metric: str = "probability_shift",
        perturbation_stats_source: str = "full",
        normalize_by_zone_size: bool = True,
        zone_size_exponent: float = 1.0,
        covariance_threshold: float = 0.01,
        var_exp: bool = True,
        show_graph_details: bool = False,
        class_threshold: float = 0.5,
    ) -> None:
        """Store the pipeline configuration and initialise empty result slots.

        Raises
        ------
        ValueError
            If *metric* is not ``'covariance'`` or ``'perturbation'``, or if
            ``metric='perturbation'`` is requested without an *estimator*.
        """
        if metric not in ("covariance", "perturbation"):
            raise ValueError(f"metric must be 'covariance' or 'perturbation', got '{metric}'.")
        if metric == "perturbation" and estimator is None:
            raise ValueError("estimator is required when metric='perturbation'.")

        self.spectral_cuts = spectral_cuts
        self.quantiles = quantiles
        self.n_repetitions = n_repetitions
        # One bagging repetition per seed: 0, 1, ..., n_repetitions - 1.
        self.seeds = list(range(n_repetitions))
        self.n_bags = n_bags
        self.n_samples_fraction = n_samples_fraction
        self.replace = replace
        self.metric = metric
        self.estimator = estimator
        self.perturbation_mode = perturbation_mode
        self.perturbation_value = perturbation_value
        self.perturbation_metric = perturbation_metric
        self.perturbation_stats_source = perturbation_stats_source
        self.normalize_by_zone_size = normalize_by_zone_size
        self.zone_size_exponent = zone_size_exponent
        self.covariance_threshold = covariance_threshold
        self.var_exp = var_exp
        self.show_graph_details = show_graph_details
        self.class_threshold = class_threshold

        # Result attributes — populated by fit()
        self.lrc_natural_: Optional[pd.DataFrame] = None
        self.lrc_summed_: Optional[pd.DataFrame] = None
        self.lrc_summed_unique_: Optional[pd.DataFrame] = None
        self.zone_scores_: Optional[pd.DataFrame] = None
        self.predicates_df_: Optional[pd.DataFrame] = None
        self.pca_info_: Optional[Dict] = None
        self.pca_info_natural_: Optional[Dict] = None
        self.zones_natural_: Optional[Dict] = None
        self.graphs_by_seed_: Dict[int, nx.DiGraph] = {}
        self.valid_seeds_: List[int] = []
        # Populated by evaluate_faithfulness()
        self.faithfulness_: Optional[Dict[str, Any]] = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _make_metric(self, X_cal_prep: pd.DataFrame, predicates_df: pd.DataFrame) -> Any:
        """Build the importance-metric object selected by ``self.metric``."""
        if self.metric == "covariance":
            return CovarianceMetric(
                metric="covariance",
                threshold=self.covariance_threshold,
            )
        return PerturbationMetric(
            estimator=self.estimator,
            Xcalclass_prep=X_cal_prep,
            predicates_df=predicates_df,
            spectral_cuts=self.spectral_cuts,
            perturbation_mode=self.perturbation_mode,
            perturbation_value=self.perturbation_value,
            stats_source=self.perturbation_stats_source,
            metric=self.perturbation_metric,
            normalize_by_zone_size=self.normalize_by_zone_size,
            zone_size_exponent=self.zone_size_exponent,
        )

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def fit(
        self,
        X_cal_prep: pd.DataFrame,
        y_pred_cal: Union[pd.Series, np.ndarray],
        X_cal_natural: Optional[pd.DataFrame] = None,
    ) -> "SMX":
        """Run the full SMX explanation pipeline.

        Parameters
        ----------
        X_cal_prep : pd.DataFrame
            Pre-processed calibration spectra (samples × features).
        y_pred_cal : pd.Series or np.ndarray
            Continuous model predictions for the calibration set (aligned with
            *X_cal_prep*).
        X_cal_natural : pd.DataFrame, optional
            Unpreprocessed calibration spectra with the same shape as
            *X_cal_prep*. Required for ``lrc_natural_`` threshold mapping.
            When ``None``, the natural-scale mapping step is skipped and
            ``lrc_natural_``, ``zones_natural_``, and ``pca_info_natural_``
            remain ``None``.

        Returns
        -------
        self

        Raises
        ------
        RuntimeError
            If no seed yields a usable LRC table.
        """
        # Rebuild the predictions as a Series with a default RangeIndex,
        # discarding any index carried by the input.
        y_pred = (
            pd.Series(y_pred_cal.values)
            if isinstance(y_pred_cal, pd.Series)
            else pd.Series(y_pred_cal)
        )

        # A faithfulness result from a previous fit describes rankings that
        # are about to be replaced; never leave it behind as stale state.
        self.faithfulness_ = None

        # ── Step 1: zone extraction + PCA aggregation ─────────────────────
        logger.debug("Extracting spectral zones…")
        zones_prep = extract_spectral_zones(X_cal_prep, self.spectral_cuts)
        aggregator = ZoneAggregator(method="pca")
        zone_scores = aggregator.fit_transform(zones_prep)
        pca_info = aggregator.pca_info_

        self.zone_scores_ = zone_scores
        self.pca_info_ = pca_info

        # ── Step 2: predicate generation ─────────────────────────────────
        logger.debug("Generating predicates…")
        gen = PredicateGenerator(quantiles=self.quantiles)
        gen.fit(zone_scores)
        predicates_df = gen.predicates_df_
        self.predicates_df_ = predicates_df

        metric_column = "Covariance" if self.metric == "covariance" else "Perturbation"

        # ── Step 3: seed loop ────────────────────────────────────────────
        lrc_by_seed: Dict[int, pd.DataFrame] = {}
        graphs_by_seed: Dict[int, nx.DiGraph] = {}

        for seed in self.seeds:
            logger.debug("Seed %d — bagging…", seed)

            # 3a. Bagging
            bagger = PredicateBagger(
                n_bags=self.n_bags,
                n_samples_fraction=self.n_samples_fraction,
                replace=self.replace,
                sample_bagging=True,
                predicate_bagging=False,
                random_seed=seed,
            )
            bags = bagger.run(zone_scores, y_pred, predicates_df)

            # Annotate bags with discrete class prediction
            for pred_dict in bags.values():
                for df_info in pred_dict.values():
                    df_info["Class_Predicted"] = np.where(
                        df_info["Predicted_Y"] >= self.class_threshold, "A", "B"
                    )

            # 3b. Metric (a fresh metric object is built for every seed)
            logger.debug("Seed %d — computing %s metric…", seed, self.metric)
            metric_obj = self._make_metric(X_cal_prep, predicates_df)
            rankings = metric_obj.compute(bags)

            # 3c. Graph
            logger.debug("Seed %d — building predicate graph…", seed)
            builder = PredicateGraphBuilder(
                random_state=seed,
                show_details=self.show_graph_details,
                var_exp=self.var_exp,
                pca_info_dict=pca_info if self.var_exp else None,
            )
            graph = builder.build(bags, rankings, metric_column=metric_column)
            # Stored even when the seed is skipped below.
            graphs_by_seed[seed] = graph

            # 3d. LRC
            predicate_nodes = [
                n
                for n, attr in graph.nodes(data=True)
                if attr.get("node_type") == "predicate"
            ]
            if len(predicate_nodes) < 1 or graph.number_of_nodes() < 2:
                logger.warning(
                    "Seed %d produced an undersized graph (%s, nodes=%d, predicate_nodes=%d) — skipping.",
                    seed,
                    self.metric,
                    graph.number_of_nodes(),
                    len(predicate_nodes),
                )
                continue

            lrc_df_seed = compute_lrc(graph, predicates_df)
            if lrc_df_seed.empty:
                logger.warning(
                    "Seed %d produced an empty LRC table after graph processing (%s) — skipping.",
                    seed,
                    self.metric,
                )
                continue

            lrc_df_seed["Seed"] = seed
            lrc_by_seed[seed] = lrc_df_seed

        if not lrc_by_seed:
            raise RuntimeError(
                f"All seeds produced empty graphs for metric='{self.metric}'. "
                "The model predictions may be degenerate (e.g. all on one side)."
            )

        self.graphs_by_seed_ = graphs_by_seed
        self.valid_seeds_ = list(lrc_by_seed.keys())

        # ── Step 4: aggregate across seeds ───────────────────────────────
        logger.debug("Aggregating LRC across %d valid seeds…", len(self.valid_seeds_))
        lrc_summed, lrc_summed_unique = aggregate_lrc_across_seeds(
            lrc_by_seed, self.valid_seeds_
        )
        self.lrc_summed_ = lrc_summed
        self.lrc_summed_unique_ = lrc_summed_unique

        # ── Step 5: map thresholds to natural scale (optional) ───────────
        if X_cal_natural is not None:
            logger.debug("Mapping thresholds to natural scale…")
            zones_natural = extract_spectral_zones(X_cal_natural, self.spectral_cuts)
            agg_natural = ZoneAggregator(method="pca")
            zone_scores_natural = agg_natural.fit_transform(zones_natural)

            self.zones_natural_ = zones_natural
            self.pca_info_natural_ = agg_natural.pca_info_
            self.lrc_natural_ = map_thresholds_to_natural(
                lrc_df=lrc_summed,
                zone_sums_preprocessed=zone_scores,
                zone_sums_natural=zone_scores_natural,
            )
        else:
            logger.info(
                "X_cal_natural was not provided to SMX.fit(); skipping natural-scale threshold mapping. "
                "Preprocessed-scale outputs (lrc_summed_, lrc_summed_unique_) remain available."
            )
            # Clear natural-scale outputs left over from an earlier fit.
            self.lrc_natural_ = None
            self.zones_natural_ = None
            self.pca_info_natural_ = None

        return self

    def evaluate_faithfulness(
        self,
        X_eval: pd.DataFrame,
        *,
        ranking: Literal["unique", "summed", "natural"] = "unique",
        X_reference: Optional[pd.DataFrame] = None,
        metric: Literal[
            "auto",
            "probability_shift",
            "mean_abs_diff",
            "decision_function_shift",
        ] = "auto",
        masking_strategy: Literal["zero", "constant", "mean", "median", "min", "max"] = "zero",
        constant_value: float = 0.0,
        max_k: Optional[int] = None,
        n_random_rankings: int = 100,
        random_state: Optional[int] = 42,
        output_path: Optional[Union[str, Path]] = None,
        plot_title: Optional[str] = None,
        plot_width: int = 1100,
        plot_height: int = 560,
    ) -> Dict[str, Any]:
        """Evaluate SMX faithfulness via progressive top-k zone masking.

        The ranked spectral zones are progressively masked on *X_eval*
        following the selected SMX ranking, and the resulting prediction shift
        is summarized by the area under the masking curve (AUC).

        Parameters
        ----------
        X_eval : pd.DataFrame
            Evaluation spectra to be masked progressively.
        ranking : {'unique', 'summed', 'natural'}, default 'unique'
            Ranking table used to derive the ordered list of spectral zones.
            ``'unique'`` uses the one-zone-per-row ranking in
            :attr:`lrc_summed_unique_`. ``'summed'`` and ``'natural'`` are
            deduplicated internally to one row per zone before masking.
        X_reference : pd.DataFrame, optional
            Reference spectra used to compute replacement values for non-zero
            masking strategies. Defaults to *X_eval*.
        metric : {'auto', 'probability_shift', 'mean_abs_diff', 'decision_function_shift'}, default 'auto'
            Prediction-shift metric to evaluate. ``'auto'`` chooses
            ``'probability_shift'`` when the estimator exposes
            ``predict_proba()``, ``'decision_function_shift'`` when it exposes
            ``decision_function()``, otherwise ``'mean_abs_diff'``.
        masking_strategy : {'zero', 'constant', 'mean', 'median', 'min', 'max'}, default 'zero'
            How masked spectral variables are replaced.
        constant_value : float, default 0.0
            Replacement value used when ``masking_strategy='constant'``.
        max_k : int, optional
            Maximum number of ranked zones to mask. Defaults to all ranked
            zones available in *X_eval*.
        n_random_rankings : int, default 100
            Number of random rankings used to contextualize the observed AUC.
        random_state : int, optional
            Seed controlling the random baseline.
        output_path : str or Path, optional
            If provided, also export a faithfulness plot to this path. The
            extension determines the format (``.html`` or a static image).
        plot_title : str, optional
            Title override used when *output_path* is provided.
        plot_width : int, default 1100
            Plot width in pixels. Used only when *output_path* is provided.
        plot_height : int, default 560
            Plot height in pixels. Used only when *output_path* is provided.

        Returns
        -------
        dict
            Faithfulness summary including ``curve_df``, ``auc``,
            ``auc_normalized``, ``level``, and null-baseline statistics. The
            keys ``ranking_source`` and, when a plot is exported,
            ``plot_path`` are added here. The same dict is stored in
            ``faithfulness_``.

        Raises
        ------
        RuntimeError
            If no estimator is set, if SMX has not been fitted, or if
            ``ranking='natural'`` is requested after fitting without
            ``X_cal_natural``.
        ValueError
            If *ranking* is not one of the supported names.
        """
        if self.estimator is None:
            raise RuntimeError(
                "SMX requires a fitted estimator to evaluate faithfulness."
            )

        ranking_map = {
            "unique": self.lrc_summed_unique_,
            "summed": self.lrc_summed_,
            "natural": self.lrc_natural_,
        }
        if ranking not in ranking_map:
            raise ValueError("ranking must be 'unique', 'summed', or 'natural'.")

        # Fitted, but without natural-scale data: telling the user to "fit
        # first" would be misleading, so report the actual cause.
        if (
            ranking == "natural"
            and self.lrc_natural_ is None
            and self.lrc_summed_ is not None
        ):
            raise RuntimeError(
                "Natural-scale ranking was requested (ranking='natural'), but SMX.fit() was called without X_cal_natural. "
                "Re-run SMX.fit(..., X_cal_natural=...) to compute natural-scale thresholds."
            )

        ranking_df = ranking_map[ranking]
        if ranking_df is None or ranking_df.empty:
            raise RuntimeError(
                f"No ranking data is available for ranking='{ranking}'. Fit SMX before "
                "calling evaluate_faithfulness()."
            )

        result = progressive_masking_faithfulness(
            estimator=self.estimator,
            X_eval=X_eval,
            spectral_cuts=self.spectral_cuts,
            ranking_df=ranking_df,
            X_reference=X_reference,
            metric=metric,
            masking_strategy=masking_strategy,
            constant_value=constant_value,
            max_k=max_k,
            n_random_rankings=n_random_rankings,
            random_state=random_state,
        )
        result["ranking_source"] = ranking

        if output_path is not None:
            # Imported lazily so the plotting stack is only needed on export.
            from smx.plotting import plot_faithfulness_curve

            plot_faithfulness_curve(
                faithfulness_result=result,
                output_path=output_path,
                title=plot_title,
                width=plot_width,
                height=plot_height,
            )
            result["plot_path"] = str(output_path)

        self.faithfulness_ = result
        return result

    def plot_zone_ranking_over_spectrum(
        self,
        output_path: Union[str, Path],
        *,
        ranking: Literal["unique", "natural"] = "unique",
        aggregation: Literal["mean", "median"] = "mean",
        title: Optional[str] = None,
        X_natural: Optional[pd.DataFrame] = None,
        y_labels: Optional[pd.Series] = None,
        class_colors: Optional[dict] = None,
        width: Optional[int] = 1200,
        height: Optional[int] = 500,
    ) -> pd.DataFrame:
        """Plot ranked spectral zones over a reference spectrum and save to file.

        The output format is inferred from *output_path* — ``.html`` for an
        interactive figure, or ``.png`` / ``.svg`` / ``.pdf`` for a static
        image (requires ``kaleido``).

        Parameters
        ----------
        output_path : str or Path
            Destination file; its extension selects the output format.
        ranking : {'unique', 'natural'}, default 'unique'
            Ranking source. ``'unique'`` uses ``lrc_summed_unique_`` (one row
            per zone). ``'natural'`` uses ``lrc_natural_`` and collapses
            multiple predicates per zone to the strongest LRC value.
        aggregation : {'mean', 'median'}, default 'mean'
            Aggregation used to build the reference spectrum from
            ``zones_natural_``.
        title : str, optional
            Figure title override.
        X_natural : pd.DataFrame, optional
            Full calibration dataset in natural (unpreprocessed) units. When
            provided together with *y_labels*, a mean spectrum is drawn for
            each class on top of the overall reference spectrum.
        y_labels : pd.Series, optional
            Class labels aligned with the rows of *X_natural*. Required when
            *X_natural* is given; if only one of the two is supplied, a
            warning is logged and no per-class spectra are drawn.
        class_colors : dict[str, str], optional
            Mapping from class label to hex/CSS color string. Missing labels
            fall back to a built-in palette.
        width : int, default 1200
            Figure width in pixels. Used only for static image exports.
        height : int, default 500
            Figure height in pixels. Used only for static image exports.

        Returns
        -------
        pd.DataFrame
            Normalized ranking table used in the figure.

        Raises
        ------
        RuntimeError
            If SMX has not been fitted, was fitted without ``X_cal_natural``,
            or the selected ranking table is empty.
        ValueError
            If *ranking* is not ``'unique'`` or ``'natural'``.
        """
        if self.lrc_summed_ is None:
            raise RuntimeError(
                "SMX must be fitted before calling plot_zone_ranking_over_spectrum."
            )
        if self.zones_natural_ is None:
            raise RuntimeError(
                "SMX was fitted without X_cal_natural, so no natural-scale reference spectrum is available for plotting. "
                "Re-run SMX.fit(..., X_cal_natural=...) to enable plotting over the natural spectrum."
            )

        if ranking == "unique":
            ranking_df = self.lrc_summed_unique_
        elif ranking == "natural":
            ranking_df = self.lrc_natural_
        else:
            raise ValueError("ranking must be 'unique' or 'natural'.")

        if ranking == "natural" and self.lrc_natural_ is None:
            raise RuntimeError(
                "Natural-scale ranking was requested (ranking='natural'), but SMX.fit() was called without X_cal_natural. "
                "Re-run SMX.fit(..., X_cal_natural=...) to compute natural-scale thresholds."
            )

        if ranking_df is None or ranking_df.empty:
            raise RuntimeError(
                "No ranking information is available. Fit SMX successfully before plotting."
            )

        # Per-class spectra need both inputs; make a lone argument visible
        # instead of dropping it silently.
        if (X_natural is None) != (y_labels is None):
            logger.warning(
                "X_natural and y_labels must be provided together; "
                "per-class spectra will not be drawn."
            )

        class_spectra = None
        if X_natural is not None and y_labels is not None:
            # One sub-frame of X_natural per distinct label, selected by a
            # positional boolean mask built from the label values.
            class_spectra = {
                str(cls): X_natural[y_labels.values == cls]
                for cls in y_labels.unique()
            }

        # Imported after validation so state errors surface even when the
        # plotting stack is not importable.
        from smx.plotting import plot_zone_ranking_over_spectrum

        return plot_zone_ranking_over_spectrum(
            zone_ranking_df=ranking_df,
            spectral_cuts=self.spectral_cuts,
            reference_spectrum=self.zones_natural_,
            output_path=output_path,
            aggregation=aggregation,
            title=title or "SMX zone ranking over spectrum",
            class_spectra=class_spectra,
            class_colors=class_colors,
            width=width,
            height=height,
        )

    def plot_faithfulness(
        self,
        output_path: Union[str, Path],
        *,
        title: Optional[str] = None,
        width: int = 1100,
        height: int = 560,
    ) -> pd.DataFrame:
        """Plot the progressive masking faithfulness curve saved in ``faithfulness_``.

        Parameters
        ----------
        output_path : str or Path
            Destination file. Use ``.html`` for an interactive figure or an
            image extension for static export.
        title : str, optional
            Figure title override.
        width : int, default 1100
            Figure width in pixels. Used only for static image exports.
        height : int, default 560
            Figure height in pixels. Used only for static image exports.

        Returns
        -------
        pd.DataFrame
            Faithfulness masking curve used in the figure.

        Raises
        ------
        RuntimeError
            If :meth:`evaluate_faithfulness` has not produced a result yet.
        """
        if self.faithfulness_ is None:
            raise RuntimeError(
                "No faithfulness result is available. Call evaluate_faithfulness() "
                "before plot_faithfulness()."
            )

        # Imported after validation so the state error surfaces even when the
        # plotting stack is not importable.
        from smx.plotting import plot_faithfulness_curve

        return plot_faithfulness_curve(
            faithfulness_result=self.faithfulness_,
            output_path=output_path,
            title=title,
            width=width,
            height=height,
        )