Source code for smx.predicates.metrics

"""
Predicate metric strategy classes.

Each class implements the ``compute(bags_dict) → dict[str, DataFrame]``
interface so they can be swapped transparently in the SMX pipeline.

Available metrics
-----------------
* :class:`CovarianceMetric` — covariance (or mutual information) between
  zone scores and model predictions within each predicate bag.
* :class:`PerturbationMetric` — perturbation-based importance: replace the
  spectral zone of each predicate with a constant/statistic value and measure
  the impact on model predictions.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import pandas as pd


# ---------------------------------------------------------------------------
# Abstract base
# ---------------------------------------------------------------------------


[docs] class BasePredicateMetric(ABC): """Strategy interface for predicate importance metrics. Subclasses implement :meth:`compute`, which accepts a bags dictionary (as returned by :class:`smx.predicates.bagging.PredicateBagger`) and returns a dictionary mapping bag name → DataFrame with columns ``['Predicate', <MetricName>]``. """
[docs] @abstractmethod def compute(self, bags_dict: Dict[str, Dict[str, pd.DataFrame]]) -> Dict[str, pd.DataFrame]: """Compute metric for every predicate in every bag. Parameters ---------- bags_dict : dict ``{'Bag_1': {rule: DataFrame(['Zone_Sum', 'Predicted_Y', ...]), ...}, ...}`` Returns ------- dict[str, pd.DataFrame] ``{'Bag_1': DataFrame(['Predicate', MetricName]), ...}`` Each inner DataFrame is sorted descending by the metric column. """
# --------------------------------------------------------------------------- # Covariance / Mutual Information metric # ---------------------------------------------------------------------------
[docs] class CovarianceMetric(BasePredicateMetric): """Association metric between zone scores and model predictions. Supports two association measures: * ``'covariance'`` — absolute covariance between zone aggregation values and continuous model predictions (linear dependency). * ``'mutual_info'`` — mutual information (captures non-linear dependencies, requires ``scikit-learn``). Parameters ---------- metric : {'covariance', 'mutual_info'}, default 'covariance' Association measure to compute. threshold : float, default 0.01 Predicates with metric value ≤ threshold are excluded from the result. n_neighbors : int, default 10 Number of nearest neighbours for mutual information estimation. Ignored when ``metric='covariance'``. """ _METRIC_COL_NAMES = { "covariance": "Covariance", "mutual_info": "Mutual_Info", } def __init__( self, metric: Literal["covariance", "mutual_info"] = "covariance", threshold: float = 0.01, n_neighbors: int = 10, ) -> None: if metric not in self._METRIC_COL_NAMES: raise ValueError( f"metric must be one of {list(self._METRIC_COL_NAMES)}. Got '{metric}'." ) self.metric = metric self.threshold = threshold self.n_neighbors = n_neighbors @property def metric_column(self) -> str: return self._METRIC_COL_NAMES[self.metric]
[docs] def compute(self, bags_dict: Dict[str, Dict[str, pd.DataFrame]]) -> Dict[str, pd.DataFrame]: """Compute the association metric for each predicate in each bag. Parameters ---------- bags_dict : dict Bags as returned by :class:`smx.predicates.bagging.PredicateBagger`. Returns ------- dict[str, pd.DataFrame] Keys = bag names. Each DataFrame has columns ``['Predicate', 'Covariance']`` (or ``'Mutual_Info'``), sorted descending by the metric, filtered by *threshold*. """ if self.metric == "mutual_info": from sklearn.feature_selection import mutual_info_regression results: Dict[str, pd.DataFrame] = {} for bag_name, predicates_dict in bags_dict.items(): if not predicates_dict: continue metrics: Dict[str, float] = {} for rule, df_info in predicates_dict.items(): X_zone = df_info["Zone_Sum"].values.reshape(-1, 1) y_pred = df_info["Predicted_Y"].values if len(X_zone) < 2: metrics[rule] = 0.0 continue if self.metric == "covariance": cov_mat = np.cov(X_zone.flatten(), y_pred) metrics[rule] = float(np.abs(cov_mat[0, 1])) else: # mutual_info mi = mutual_info_regression( X_zone, y_pred, discrete_features=False, n_neighbors=self.n_neighbors, random_state=42, ) metrics[rule] = float(mi[0]) metrics_df = ( pd.DataFrame.from_dict(metrics, orient="index", columns=[self.metric_column]) .rename_axis(None) .reset_index() .rename(columns={"index": "Predicate"}) ) metrics_df = metrics_df.sort_values(self.metric_column, ascending=False).reset_index(drop=True) metrics_df = metrics_df[metrics_df[self.metric_column] > self.threshold].reset_index(drop=True) results[bag_name] = metrics_df return results
# --------------------------------------------------------------------------- # Perturbation metric # --------------------------------------------------------------------------- def _get_zone_columns( predicate_rule: str, predicates_df: pd.DataFrame, spectral_cuts: List[Tuple], dataset_columns: pd.Index, ) -> Tuple[List[str], Optional[float], Optional[float]]: """Return (zone_cols, zone_start, zone_end) for a predicate rule. Handles 2-element ``(start, end)``, 3-element ``(name, start, end)``, and 4-element ``(name, start, end, group)`` cuts. When the zone name matches a *group*, all cuts belonging to that group are collected and their column ranges unioned. """ mask = predicates_df["rule"] == predicate_rule if not mask.any(): raise KeyError(f"Predicate '{predicate_rule}' not found in predicates_df.") zone_name = predicates_df.loc[mask, "zone"].values[0] col_numeric = pd.to_numeric(dataset_columns.astype(str), errors="coerce") # ── Collect column ranges that belong to this zone ─────────────────── # A zone_name may come from: # (a) a direct 2- or 3-element cut whose name matches, OR # (b) a 4-element cut whose *group* field matches (grouped zone) col_mask = np.zeros(len(dataset_columns), dtype=bool) found = False for cut in spectral_cuts: if isinstance(cut, dict): cut_name = cut.get("name", f"{cut.get('start')}-{cut.get('end')}") cut_start = cut.get("start") cut_end = cut.get("end") cut_group = cut.get("group", None) elif isinstance(cut, (list, tuple)): if len(cut) == 2: cut_start, cut_end = cut cut_name = f"{cut_start}-{cut_end}" cut_group = None elif len(cut) == 3: cut_name, cut_start, cut_end = cut cut_group = None elif len(cut) == 4: cut_name, cut_start, cut_end, cut_group = cut else: continue else: continue # Direct name match (ungrouped cut) or group match (grouped cut) if cut_name == zone_name or cut_group == zone_name: try: s, e = float(cut_start), float(cut_end) except Exception: continue if s > e: s, e = e, s range_mask = (~np.isnan(col_numeric)) & (col_numeric >= s) & (col_numeric <= e) col_mask |= range_mask found = True if not found: return [], None, None zone_cols = list(dataset_columns[col_mask]) # For zone_start/zone_end return the overall span (used for logging only) matched_vals = col_numeric[col_mask] zone_start = float(matched_vals.min()) if len(matched_vals) else None zone_end = float(matched_vals.max()) if len(matched_vals) else None return zone_cols, zone_start, zone_end
[docs] class PerturbationMetric(BasePredicateMetric): """Spectral-perturbation importance metric. For each predicate, the spectral zone is temporarily replaced by a fixed value (or a per-column statistic) and the change in model prediction is measured. Parameters ---------- estimator : sklearn-compatible estimator Trained model with a ``predict()`` method. Xcalclass_prep : pd.DataFrame Pre-processed calibration dataset (samples × features). predicates_df : pd.DataFrame Predicate catalogue with columns ``'rule'`` and ``'zone'``. spectral_cuts : list of (name, start, end) tuples Defines every spectral zone boundary. perturbation_value : float, default 0 Constant replacement value when ``perturbation_mode='constant'``. perturbation_mode : {'constant', 'mean', 'median', 'min', 'max'}, default 'constant' How to replace the zone. stats_source : {'full', 'predicate'}, default 'full' Data source for computing per-column statistics. metric : str, default 'mean_abs_diff' Importance metric. See :class:`smx.predicates.metrics.PerturbationMetric` docstring for available options per *aim*. normalize_by_zone_size : bool, default False Divide raw importance by the number of zone features (raised to *zone_size_exponent*) to compensate for wide-zone bias. zone_size_exponent : float, default 1.0 Exponent applied to zone size for normalisation. verbose : bool, default False Print per-predicate progress. save_detailed_results : bool, default True Attach a ``'__detailed_perturbation_results__'`` key to the result. """ _REGRESSION_METRICS = {"mean_abs_diff", "mean_diff", "mean_relative_dev"} _CLASSIFICATION_METRICS = { "prediction_change_rate", "accuracy_drop", "f1_drop", "probability_shift", "decision_function_shift", } def __init__( self, estimator: Any, Xcalclass_prep: pd.DataFrame, predicates_df: pd.DataFrame, spectral_cuts: List[Tuple[str, float, float]], perturbation_value: float = 0, perturbation_mode: Literal["constant", "mean", "median", "min", "max"] = "constant", stats_source: Literal["full", "predicate"] = "full", metric: str = "mean_abs_diff", normalize_by_zone_size: bool = False, zone_size_exponent: float = 1.0, verbose: bool = False, save_detailed_results: bool = True, ) -> None: aim = "classification" if metric in self._CLASSIFICATION_METRICS else "regression" if metric in self._REGRESSION_METRICS else None if aim is None: raise ValueError(f"Invalid metric '{metric}'. Must be one of {self._REGRESSION_METRICS} or {self._CLASSIFICATION_METRICS}.") if metric == "probability_shift" and not hasattr(estimator, "predict_proba"): raise ValueError( "'probability_shift' requires an estimator with predict_proba(). " "For SVC, use SVC(probability=True)." ) if metric == "decision_function_shift" and not hasattr(estimator, "decision_function"): raise ValueError( "'decision_function_shift' requires an estimator with decision_function()." ) self.estimator = estimator self.Xcalclass_prep = Xcalclass_prep self.predicates_df = predicates_df self.spectral_cuts = spectral_cuts self.aim = aim self.perturbation_value = perturbation_value self.perturbation_mode = perturbation_mode self.stats_source = stats_source self.metric = metric self.normalize_by_zone_size = normalize_by_zone_size self.zone_size_exponent = zone_size_exponent self.verbose = verbose self.save_detailed_results = save_detailed_results @property def metric_column(self) -> str: return "Perturbation"
[docs] def compute(self, bags_dict: Dict[str, Dict[str, pd.DataFrame]]) -> Dict[str, pd.DataFrame]: """Compute perturbation importance for each predicate in each bag. Parameters ---------- bags_dict : dict Bags as returned by :class:`smx.predicates.bagging.PredicateBagger`. Returns ------- dict[str, pd.DataFrame] Keys = bag names. Each DataFrame has columns ``['Predicate', 'Perturbation']``, sorted descending. When ``save_detailed_results=True`` the key ``'__detailed_perturbation_results__'`` is also included. """ from sklearn.metrics import accuracy_score, f1_score results: Dict[str, pd.DataFrame] = {} detailed_results: Dict[str, Dict] = {} for fold_name, predicates_dict in bags_dict.items(): if not predicates_dict: results[fold_name] = pd.DataFrame({"Predicate": [], "Perturbation": []}) continue fold_metrics: Dict[str, float] = {} fold_detailed: Dict[str, Dict] = {} for pred_rule, df_info in predicates_dict.items(): sample_indices = df_info["Sample_Index"].values.tolist() n_samples = len(sample_indices) if n_samples == 0: fold_metrics[pred_rule] = 0.0 fold_detailed[pred_rule] = {"importance": 0.0, "skip_reason": "n_samples=0"} continue # ── Get zone columns ────────────────────────────────────── try: zone_cols, zone_start, zone_end = _get_zone_columns( pred_rule, self.predicates_df, self.spectral_cuts, self.Xcalclass_prep.columns, ) except (KeyError, ValueError) as exc: fold_metrics[pred_rule] = 0.0 fold_detailed[pred_rule] = {"importance": 0.0, "skip_reason": str(exc)} continue if not zone_cols: fold_metrics[pred_rule] = 0.0 fold_detailed[pred_rule] = {"importance": 0.0, "skip_reason": "empty zone"} continue X_eval = self.Xcalclass_prep.iloc[sample_indices].copy() # ── Perturb zone ────────────────────────────────────────── X_perturbed = X_eval.copy() if self.perturbation_mode == "constant": X_perturbed[zone_cols] = self.perturbation_value else: src = ( self.Xcalclass_prep[zone_cols] if self.stats_source == "full" else X_eval[zone_cols] ) col_stats = getattr(src, self.perturbation_mode)(axis=0) for col in zone_cols: X_perturbed[col] = col_stats[col] # ── Compute importance ──────────────────────────────────── try: importance, importance_for_ranking = self._compute_importance( X_eval, X_perturbed, accuracy_score, f1_score ) except (TypeError, ValueError) as exc: try: y_sample = np.array(self.estimator.predict(X_eval.iloc[:1])).flatten() pred_dtype = y_sample.dtype is_numeric = np.issubdtype(pred_dtype, np.number) except Exception: pred_dtype = "unknown" is_numeric = None if self.aim == "regression" and is_numeric is False: hint = ( f"Metric '{self.metric}' requires numeric predictions, but the " f"estimator returned dtype '{pred_dtype}' (e.g. class labels). " f"Switch to a classification metric such as 'prediction_change_rate'." ) elif self.aim == "classification" and is_numeric is True: hint = ( f"Metric '{self.metric}' is a classification metric, but the " f"estimator appears to return numeric values (dtype '{pred_dtype}'). " f"Switch to a regression metric such as 'mean_abs_diff'." ) else: hint = ( f"Metric '{self.metric}' is incompatible with this estimator " f"(prediction dtype: '{pred_dtype}'). " f"Original error: {exc}" ) raise TypeError(hint) from exc # ── Zone-size normalisation ─────────────────────────────── n_zone_features = len(zone_cols) if self.normalize_by_zone_size and n_zone_features > 0: importance_for_ranking /= n_zone_features ** self.zone_size_exponent fold_metrics[pred_rule] = float(importance_for_ranking) fold_detailed[pred_rule] = { "importance": float(importance), "importance_normalized": float(importance_for_ranking), "n_samples": n_samples, "n_zone_features": n_zone_features, } if self.verbose: print(f" {pred_rule} (n={n_samples}): {importance_for_ranking:.6f}") metrics_df = ( pd.DataFrame.from_dict(fold_metrics, orient="index", columns=["Perturbation"]) .rename_axis(None) .reset_index() .rename(columns={"index": "Predicate"}) .sort_values("Perturbation", ascending=False) .reset_index(drop=True) ) results[fold_name] = metrics_df detailed_results[fold_name] = fold_detailed if self.save_detailed_results: rows = [ {"fold": fold, "predicate": rule, **data} for fold, fold_data in detailed_results.items() for rule, data in fold_data.items() ] results["__detailed_perturbation_results__"] = pd.DataFrame(rows) return results
# ------------------------------------------------------------------ # Private helpers # ------------------------------------------------------------------ def _compute_importance( self, X_eval: pd.DataFrame, X_perturbed: pd.DataFrame, accuracy_score, f1_score, ) -> Tuple[float, float]: """Return (raw_importance, importance_for_ranking).""" if self.aim == "regression": y_orig = np.array(self.estimator.predict(X_eval)).flatten() y_pert = np.array(self.estimator.predict(X_perturbed)).flatten() if self.metric == "mean_abs_diff": imp = float(np.mean(np.abs(y_orig - y_pert))) return imp, imp elif self.metric == "mean_diff": imp = float(np.mean(y_orig - y_pert)) return imp, float(np.abs(imp)) else: # mean_relative_dev y_safe = np.where(y_orig == 0, np.nan, y_orig) rel = (y_pert - y_orig) / y_safe imp = float(np.nanmean(rel)) return imp, float(np.abs(imp)) # ── Classification ──────────────────────────────────────────────── if self.metric == "prediction_change_rate": y_orig = np.array(self.estimator.predict(X_eval)).flatten() y_pert = np.array(self.estimator.predict(X_perturbed)).flatten() imp = float(np.mean(y_orig != y_pert)) return imp, imp elif self.metric == "accuracy_drop": y_orig = np.array(self.estimator.predict(X_eval)).flatten() y_pert = np.array(self.estimator.predict(X_perturbed)).flatten() imp = float(1.0 - accuracy_score(y_orig, y_pert)) return imp, imp elif self.metric == "f1_drop": y_orig = np.array(self.estimator.predict(X_eval)).flatten() y_pert = np.array(self.estimator.predict(X_perturbed)).flatten() imp = float(1.0 - f1_score(y_orig, y_pert, average="weighted")) return imp, imp elif self.metric == "probability_shift": prob_orig = self.estimator.predict_proba(X_eval) prob_pert = self.estimator.predict_proba(X_perturbed) shift = np.mean(np.sum(np.abs(prob_orig - prob_pert), axis=1) / 2.0) return float(shift), float(shift) else: # decision_function_shift df_orig = np.array(self.estimator.decision_function(X_eval)) df_pert = np.array(self.estimator.decision_function(X_perturbed)) if df_orig.ndim == 1: df_orig = df_orig.flatten() df_pert = df_pert.flatten() imp = float(np.mean(np.abs(df_orig - df_pert))) return imp, imp