# Source code for smx.plotting.summary

"""
Summary and diagnostic plots for SMX explanation results.

Functions
---------
plot_lrc_bar
    Horizontal bar chart of LRC scores per zone.
plot_predicate_heatmap
    Zone × predicate heatmap of LRC scores.
plot_zone_scores
    Split-violin of PC1 scores per zone by class.
plot_all_thresholds_overlay
    Full-spectrum overlay of top-predicate threshold per zone.
plot_faithfulness_curve
    Progressive masking curve with shaded AUC and summary annotations.
"""

from __future__ import annotations

from pathlib import Path
from typing import Dict, Iterable, List, Optional, Union

import numpy as np
import pandas as pd

from smx.plotting.theme import DEFAULT_THEME, SMXTheme, blend_with_white, build_blended_colorscale


# ── Shared helpers ─────────────────────────────────────────────────────────────

def _write_figure(fig, output_path: Optional[Union[str, Path]], width: int, height: int) -> None:
    """Write *fig* to *output_path*, picking the writer from the file suffix.

    Parameters
    ----------
    fig
        Object exposing ``write_html`` and ``write_image`` (a Plotly figure
        in this module).
    output_path : str or Path, optional
        Destination file. ``None`` means "do not write anything".
    width, height : int
        Pixel dimensions forwarded to ``write_image`` (static export only).

    Raises
    ------
    ValueError
        If the suffix is neither ``.html`` nor one of the static image
        suffixes handled below.
    ImportError
        If ``write_image`` raises ``ValueError``; it is re-raised as a hint
        to install kaleido.
    """
    if output_path is None:
        return

    target = Path(output_path)
    extension = target.suffix.lower()
    static_extensions = (".png", ".svg", ".pdf", ".jpg", ".jpeg", ".webp")

    # Interactive export.
    if extension == ".html":
        fig.write_html(str(target))
        return

    # Anything that is not a known static format is rejected up front.
    if extension not in static_extensions:
        raise ValueError(
            f"Unsupported output format '{extension}'. "
            "Use '.html' for interactive or '.png'/'.svg'/'.pdf' for static image."
        )

    # Static export: a ValueError here is surfaced as a missing-kaleido hint.
    try:
        fig.write_image(str(target), width=width, height=height)
    except ValueError as exc:
        raise ImportError(
            "Static image export requires kaleido. "
            "Install it with: pip install kaleido"
        ) from exc


def _require_plotly():
    """Return the ``plotly.graph_objects`` module, imported on demand.

    Raises
    ------
    ImportError
        If plotly cannot be imported; the message carries an install hint
        and the original error is chained.
    """
    try:
        import plotly.graph_objects as graph_objects
    except ImportError as exc:
        raise ImportError(
            "plotly is required for SMX plotting. "
            "Install it with: pip install plotly"
        ) from exc
    return graph_objects


# ── 1. LRC Bar Chart ───────────────────────────────────────────────────────────

def plot_lrc_bar(
    zone_ranking_df: pd.DataFrame,
    output_path: Optional[Union[str, Path]],
    *,
    title: Optional[str] = None,
    colorscale: Optional[str] = None,
    theme: Optional[SMXTheme] = None,
    width: int = 800,
    height: int = 500,
    return_df: bool = False,
) -> Union[None, pd.DataFrame]:
    """Draw a horizontal bar chart of each zone's share of the total LRC score.

    One bar is drawn per spectral zone. Bar length is the zone's score as a
    percentage of the summed score; bar color is sampled from the colorscale
    at the zone's min-max normalized score and then blended with white using
    ``theme.zone_opacity``. The figure is always shown via ``fig.show()``.

    Parameters
    ----------
    zone_ranking_df : pd.DataFrame
        LRC table (``Zone`` / ``Local_Reaching_Centrality`` columns) or a
        pre-normalized ``zone`` / ``score`` / ``rank`` DataFrame; it is
        normalized through ``_prepare_zone_ranking_df``.
    output_path : str or Path, optional
        Destination file. If ``None``, no file is written.
    title : str, optional
        Figure title. Defaults to ``"LRC Score by Spectral Zone"``.
    colorscale : str, optional
        Plotly colorscale name. Defaults to ``theme.colorscale``.
    theme : SMXTheme, optional
        Visual theme. Defaults to :data:`~smx.plotting.theme.DEFAULT_THEME`.
    width : int, default 800
        Figure width in pixels (static export only).
    height : int, default 500
        Figure height in pixels (static export only).
    return_df : bool, default False
        If ``True``, return the ranking DataFrame used for plotting.

    Returns
    -------
    pd.DataFrame or None
        The ranking DataFrame sorted by ascending ``score`` with an added
        ``pct`` column when *return_df* is True, otherwise ``None``.
    """
    go = _require_plotly()
    from plotly.colors import sample_colorscale
    from smx.plotting.zones import _prepare_zone_ranking_df

    theme = theme or DEFAULT_THEME
    active_scale = colorscale or theme.colorscale
    bar_opacity = theme.zone_opacity

    # Ascending sort so the highest-scoring zone ends up as the last bar.
    table = _prepare_zone_ranking_df(zone_ranking_df)
    table = table.sort_values("score", ascending=True)
    total = float(table["score"].sum())
    table["pct"] = table["score"] / max(total, 1e-9) * 100

    lowest = float(table["score"].min())
    highest = float(table["score"].max())
    # Floor the span so identical scores do not divide by zero.
    span = max(highest - lowest, 1e-9)

    bar_colors = []
    for value in table["score"]:
        sampled = sample_colorscale(active_scale, [(value - lowest) / span])[0]
        bar_colors.append(blend_with_white(sampled, bar_opacity))

    bar_labels = [
        "#" + str(int(rank)) + " " + zone
        for rank, zone in zip(table["rank"], table["zone"])
    ]
    share_text = [f"{share:.1f}%" for share in table["pct"]]

    bars = go.Bar(
        x=table["pct"],
        y=bar_labels,
        orientation="h",
        marker=dict(color=bar_colors, line=dict(color="#555555", width=1)),
        text=share_text,
        textposition="outside",
        hovertemplate="Zone: %{y}<br>Share: %{x:.2f}%<br>LRC: %{customdata:.4f}<extra></extra>",
        customdata=table["score"].tolist(),
    )
    fig = go.Figure(bars)

    # Leave 25% headroom on the x axis for the outside text labels.
    widest = float(table["pct"].max())
    layout = theme.plotly_layout(
        title=title or "LRC Score by Spectral Zone",
        xaxis=dict(title="LRC Score (% of total)", range=[0, widest * 1.25]),
        yaxis=dict(title="Zone"),
        margin=dict(t=80, r=100, b=60, l=180),
    )
    fig.update_layout(**layout)

    _write_figure(fig, output_path, width, height)
    fig.show()

    return table if return_df else None


# ── 2. Predicate Heatmap ───────────────────────────────────────────────────────
def plot_predicate_heatmap(
    lrc_natural_df: pd.DataFrame,
    output_path: Optional[Union[str, Path]],
    *,
    title: Optional[str] = None,
    colorscale: Optional[str] = None,
    theme: Optional[SMXTheme] = None,
    width: int = 1000,
    height: int = 550,
    return_df: bool = False,
) -> Union[None, pd.DataFrame]:
    """Heatmap of LRC scores across zones and predicate thresholds.

    Rows are spectral zones, reindexed by ascending maximum LRC (the original
    documentation states this renders the highest zone at the top). Columns
    are predicates within each zone, grouped by operator (``≤`` then ``>``)
    and sorted by numeric threshold rank within each group. Cells without a
    predicate are filled with a sentinel value below the data range and drawn
    in a dedicated "no data" grey.

    The figure is always displayed; set *return_df=True* to return the pivot
    DataFrame.

    Parameters
    ----------
    lrc_natural_df : pd.DataFrame
        ``smx.lrc_natural_`` — must contain ``Zone``, ``Operator``,
        ``Threshold_Natural``, and ``Local_Reaching_Centrality`` columns.
    output_path : str or Path, optional
        Destination file. If ``None``, no file is written.
    title : str, optional
        Figure title.
    colorscale : str, optional
        Plotly colorscale name. Defaults to ``theme.colorscale``.
    theme : SMXTheme, optional
        Visual theme.
    width : int, default 1000
        Figure width (static export).
    height : int, default 550
        Figure height (static export).
    return_df : bool, default False
        If ``True``, return the pivot DataFrame (zones × predicate labels).

    Returns
    -------
    pd.DataFrame or None
        Pivot DataFrame (zones × predicate labels → LRC score) when
        *return_df* is True.
    """
    go = _require_plotly()

    theme = theme or DEFAULT_THEME
    _colorscale = colorscale or theme.colorscale
    _blended = build_blended_colorscale(_colorscale, theme.zone_opacity)

    df = lrc_natural_df[lrc_natural_df["Zone"].notna()].copy()
    df = df.sort_values(["Zone", "Operator", "Threshold_Natural"])
    # Rank thresholds 1..n inside each (zone, operator) group.
    df["thresh_rank"] = df.groupby(["Zone", "Operator"]).cumcount() + 1

    op_symbol = {"<=": "≤", ">": ">"}
    df["predicate_label"] = df["Operator"].map(op_symbol) + " T" + df["thresh_rank"].astype(str)

    pivot = df.pivot_table(
        index="Zone",
        columns="predicate_label",
        values="Local_Reaching_Centrality",
        aggfunc="max",
    )

    zone_order = (
        df.groupby("Zone")["Local_Reaching_Centrality"]
        .max()
        .sort_values(ascending=True)
        .index.tolist()
    )
    pivot = pivot.reindex(zone_order)

    def _rank_of(label: str) -> int:
        # Labels are built above as "<symbol> T<rank>"; sort on the integer
        # rank so "T10" follows "T9" instead of landing between "T1" and "T2".
        return int(label.rsplit("T", 1)[1])

    le_cols = sorted((c for c in pivot.columns if c.startswith("≤")), key=_rank_of)
    gt_cols = sorted((c for c in pivot.columns if c.startswith(">")), key=_rank_of)
    pivot = pivot[le_cols + gt_cols]

    text_vals = [
        [f"{v:.3f}" if not np.isnan(v) else "—" for v in row]
        for row in pivot.values
    ]

    score_min = float(df["Local_Reaching_Centrality"].min())
    score_max = float(df["Local_Reaching_Centrality"].max())

    # Sentinel value for NaN cells: one step below score_min so they map to a
    # dedicated "no data" color at the bottom of the extended colorscale.
    _SENTINEL = score_min - (score_max - score_min) * 0.25
    _NO_DATA_COLOR = "rgb(220,220,220)"
    z_filled = np.where(np.isnan(pivot.values), _SENTINEL, pivot.values)

    # Prepend a "no data" segment [0, no_data_frac) → _NO_DATA_COLOR, then the
    # blended data colorscale occupies [no_data_frac, 1].
    _data_range = score_max - _SENTINEL
    _no_data_frac = (score_min - _SENTINEL) / max(_data_range, 1e-9)
    _extended_cs = [[0.0, _NO_DATA_COLOR], [_no_data_frac, _NO_DATA_COLOR]] + [
        [_no_data_frac + (1 - _no_data_frac) * pos, color]
        for pos, color in _blended
    ]

    fig = go.Figure(go.Heatmap(
        z=z_filled.tolist(),
        x=pivot.columns.tolist(),
        y=pivot.index.tolist(),
        colorscale=_extended_cs,
        zmin=_SENTINEL,
        zmax=score_max,
        colorbar=dict(
            title=dict(text="LRC score", side="right"),
            thickness=theme.colorbar_thickness,
            len=theme.colorbar_len,
            tickmode="array",
            tickvals=[score_min, score_max],
            ticktext=[f"{score_min:.3f} (min)", f"{score_max:.3f} (max)"],
            tickfont=dict(size=10),
        ),
        text=text_vals,
        texttemplate="%{text}",
        textfont=dict(size=9, family=theme.font_family),
        hovertemplate="Zone: %{y}<br>Predicate: %{x}<br>LRC: %{text}<extra></extra>",
        hoverongaps=False,
        xgap=2,
        ygap=2,
    ))

    # Vertical separator between ≤ and > columns
    if le_cols and gt_cols:
        fig.add_vline(
            x=len(le_cols) - 0.5,
            line=dict(color="white", width=3),
        )
        fig.add_annotation(
            x=(len(le_cols) - 1) / 2,
            y=1.1,
            xref="x",
            yref="paper",
            text="Operator ≤",
            showarrow=False,
            font=dict(size=theme.annotation_font_size, family=theme.font_family),
        )
        fig.add_annotation(
            x=len(le_cols) + (len(gt_cols) - 1) / 2,
            y=1.1,
            xref="x",
            yref="paper",
            text="Operator >",
            showarrow=False,
            font=dict(size=theme.annotation_font_size, family=theme.font_family),
        )

    fig.update_layout(
        **theme.plotly_layout(
            title=title or "Predicate LRC Heatmap",
            xaxis=dict(title="Predicate (operator · threshold rank)", tickangle=-30),
            yaxis=dict(title="Zone"),
            margin=dict(t=100, r=120, b=100, l=160),
            plot_bgcolor="#d8d8d8",
        )
    )

    _write_figure(fig, output_path, width, height)
    fig.show()

    if return_df:
        return pivot


# ── 3. Zone PC1 Score Violin ───────────────────────────────────────────────────
def plot_zone_scores(
    zones: Union[pd.DataFrame, Dict[str, pd.DataFrame]],
    y_labels: pd.Series,
    output_path: Optional[Union[str, Path]],
    spectral_cuts: Optional[Iterable] = None,
    *,
    title: Optional[str] = None,
    class_colors: Optional[Dict[str, str]] = None,
    theme: Optional[SMXTheme] = None,
    width: int = 1200,
    height: int = 580,
    return_df: bool = False,
) -> Union[None, pd.DataFrame]:
    """Draw per-zone violins of aggregated zone scores, one violin per class.

    Zones are aggregated with ``ZoneAggregator(method="pca")`` and one violin
    trace is added per (class, zone) pair. With exactly two classes the first
    class is drawn on the negative side and the second on the positive side
    (split violins); otherwise every violin uses both sides. The figure is
    always shown via ``fig.show()``.

    Parameters
    ----------
    zones : pd.DataFrame or dict[str, pd.DataFrame]
        Either the full calibration DataFrame (requires ``spectral_cuts``) or
        a pre-extracted zone dict such as ``smx.zones_natural_``.
    y_labels : pd.Series
        Class labels aligned row-wise with *zones*.
    output_path : str or Path, optional
        Destination file. If ``None``, no file is written.
    spectral_cuts : iterable, optional
        Zone boundary definitions. Required when *zones* is a DataFrame.
    title : str, optional
        Figure title.
    class_colors : dict[str, str], optional
        Per-class colors keyed by ``str(class)``. Classes without an entry
        fall back to ``theme.resolve_class_color``.
    theme : SMXTheme, optional
        Visual theme.
    width : int, default 1200
        Figure width (static export).
    height : int, default 580
        Figure height (static export).
    return_df : bool, default False
        If ``True``, return the aggregated zone scores DataFrame.

    Returns
    -------
    pd.DataFrame or None
        The aggregator's transformed zone scores when *return_df* is True.

    Raises
    ------
    ValueError
        If *zones* is a DataFrame and *spectral_cuts* is ``None``.
    """
    go = _require_plotly()
    from smx.zones.aggregation import ZoneAggregator

    theme = theme or DEFAULT_THEME
    used_colors: List[str] = []

    # Normalise the input to a {zone name: DataFrame} mapping.
    if not isinstance(zones, pd.DataFrame):
        zone_dict = zones
    else:
        if spectral_cuts is None:
            raise ValueError("spectral_cuts is required when zones is a DataFrame.")
        from smx.zones.extraction import extract_spectral_zones
        zone_dict = extract_spectral_zones(zones, spectral_cuts)

    aggregator = ZoneAggregator(method="pca")
    aggregator.fit(zone_dict)
    zone_scores_df = aggregator.transform(zone_dict)
    zone_names = zone_scores_df.columns.tolist()

    class_values = list(y_labels.unique())
    # Mirrored halves only make sense for a two-class problem.
    if len(class_values) == 2:
        side_by_class = {class_values[0]: "negative", class_values[1]: "positive"}
    else:
        side_by_class = {}

    explicit_colors = class_colors or {}

    fig = go.Figure()
    for cls in class_values:
        row_mask = (y_labels == cls).values
        n_rows = int(row_mask.sum())

        # Only ask the theme for a color when the caller did not supply one.
        trace_color = explicit_colors.get(str(cls))
        if not trace_color:
            trace_color = theme.resolve_class_color(str(cls), used_colors)

        for zone in zone_names:
            violin = go.Violin(
                x=[zone] * n_rows,
                y=zone_scores_df.loc[row_mask, zone].values,
                name=f"Class {cls}",
                legendgroup=str(cls),
                # One legend entry per class: only the first zone's trace.
                showlegend=(zone == zone_names[0]),
                side=side_by_class.get(cls, "both"),
                line_color=trace_color,
                fillcolor=trace_color,
                opacity=0.85,
                box_visible=False,
                meanline_visible=True,
                points=False,
                width=0.6,
            )
            fig.add_trace(violin)

    layout = theme.plotly_layout(
        title=title or "PC1 Scores by Spectral Zone and Class",
        xaxis=dict(title="Spectral Zone", tickangle=-30),
        yaxis=dict(title="PC 1 Score"),
        violingap=0.05,
        violingroupgap=0.0,
        legend=dict(orientation="h", y=-0.33, x=0.85, xanchor="center"),
        margin=dict(t=80, r=40, b=140, l=80),
    )
    fig.update_layout(**layout)

    _write_figure(fig, output_path, width, height)
    fig.show()

    return zone_scores_df if return_df else None


# ── 4. All-Zone Threshold Overlay ──────────────────────────────────────────────
def plot_all_thresholds_overlay(
    lrc_natural_df: pd.DataFrame,
    zones_natural: Dict[str, pd.DataFrame],
    pca_info_natural: Dict,
    y_labels: pd.Series,
    spectral_cuts: Iterable,
    output_path: Optional[Union[str, Path]],
    *,
    title: Optional[str] = None,
    class_colors: Optional[Dict[str, str]] = None,
    theme: Optional[SMXTheme] = None,
    width: int = 1200,
    height: int = 500,
    return_df: bool = False,
) -> Union[None, pd.DataFrame]:
    """Full-spectrum overlay of the top-ranked threshold per zone.

    Three layers are drawn: a shaded per-class min/max band, the per-class
    mean spectrum as a solid line, and, for each zone, the highest-LRC
    predicate threshold reconstructed to a spectrum via
    ``reconstruct_threshold_to_spectrum`` and drawn with the theme's
    threshold line style. Threshold line colors are sampled from
    ``theme.colorscale`` at the zone's min-max normalized LRC score. Zone
    boundaries are marked with vertical lines.

    The figure is always displayed; set *return_df=True* to return the
    top-predicate-per-zone DataFrame.

    Parameters
    ----------
    lrc_natural_df : pd.DataFrame
        ``smx.lrc_natural_``.
    zones_natural : dict[str, pd.DataFrame]
        ``smx.zones_natural_``.
    pca_info_natural : dict
        ``smx.pca_info_natural_``.
    y_labels : pd.Series
        Class labels aligned row-wise with the calibration data.
    spectral_cuts : iterable
        Zone boundary definitions. Each item is a dict with ``name`` /
        ``start`` / ``end`` keys, a 3-item ``(name, start, end)`` sequence,
        or a 2-item ``(start, end)`` sequence (named ``"start-end"``).
    output_path : str or Path, optional
        Destination file. If ``None``, no file is written.
    title : str, optional
        Figure title.
    class_colors : dict[str, str], optional
        Per-class hex/CSS colors keyed by ``str(class)``.
    theme : SMXTheme, optional
        Visual theme.
    width : int, default 1200
        Figure width (static export).
    height : int, default 500
        Figure height (static export).
    return_df : bool, default False
        If ``True``, return the top-predicate-per-zone DataFrame.

    Returns
    -------
    pd.DataFrame or None
        Top-predicate-per-zone DataFrame when *return_df* is True.
    """
    go = _require_plotly()
    from plotly.colors import sample_colorscale
    from smx.graph.interpretation import reconstruct_threshold_to_spectrum

    theme = theme or DEFAULT_THEME
    _used: List[str] = []

    # Keep only the highest-LRC predicate of every zone.
    top_per_zone = (
        lrc_natural_df[lrc_natural_df["Zone"].notna()]
        .sort_values("Local_Reaching_Centrality", ascending=False)
        .drop_duplicates(subset=["Zone"])
        .reset_index(drop=True)
    )

    score_min = float(top_per_zone["Local_Reaching_Centrality"].min())
    score_max = float(top_per_zone["Local_Reaching_Centrality"].max())

    def _lrc_color(score: float) -> str:
        norm = (score - score_min) / max(score_max - score_min, 1e-9)
        return sample_colorscale(theme.colorscale, [norm])[0]

    # Parse spectral cuts to get zone names and boundaries
    cut_rows = []
    for cut in spectral_cuts:
        if isinstance(cut, dict):
            cut_rows.append((str(cut["name"]), float(cut["start"]), float(cut["end"])))
        elif len(cut) == 3:
            cut_rows.append((str(cut[0]), float(cut[1]), float(cut[2])))
        else:
            cut_rows.append((f"{cut[0]}-{cut[1]}", float(cut[0]), float(cut[1])))

    # Resolve each class color exactly once so the band and the mean line of
    # a class always share it. NOTE(review): resolve_class_color receives the
    # shared `_used` list; if it picks colors based on that list, resolving a
    # class twice could hand back two different colors — confirm.
    explicit_colors = class_colors or {}
    classes = list(y_labels.unique())
    palette = [
        explicit_colors.get(str(cls)) or theme.resolve_class_color(str(cls), _used)
        for cls in classes
    ]

    # NOTE(review): the range legend entries are shown only when no (truthy)
    # class_colors key was supplied. This preserves the original condition;
    # the intent is unclear — confirm.
    show_range_legend = not any(explicit_colors)

    fig = go.Figure()

    # ── Per-class min / max bands per zone (shaded) ───────────────────────
    for cls, color in zip(classes, palette):
        mask = (y_labels == cls).values

        # Build upper (max) and lower (min) envelopes across all zones
        upper_parts, lower_parts = [], []
        for zone_name, _, _ in cut_rows:
            zone_df = zones_natural.get(zone_name)
            if zone_df is None or zone_df.empty:
                continue
            zone_cls = zone_df[mask]
            zone_max = zone_cls.max(axis=0)
            zone_min = zone_cls.min(axis=0)
            # Column labels become the numeric x axis; non-numeric ones turn
            # into NaN and are dropped below.
            zone_max.index = pd.to_numeric(zone_max.index.astype(str), errors="coerce")
            zone_min.index = pd.to_numeric(zone_min.index.astype(str), errors="coerce")
            upper_parts.append(zone_max)
            lower_parts.append(zone_min)

        if upper_parts:
            full_upper = pd.concat(upper_parts).sort_index().dropna()
            full_lower = pd.concat(lower_parts).sort_index().dropna()
            x_envelope = full_upper.index.to_numpy(dtype=float)
            y_upper = full_upper.to_numpy(dtype=float)
            y_lower = full_lower.to_numpy(dtype=float)

            # Shaded band between min and max for this class: trace the upper
            # envelope left-to-right, then the lower one right-to-left.
            fig.add_trace(go.Scatter(
                x=np.concatenate([x_envelope, x_envelope[::-1]]),
                y=np.concatenate([y_upper, y_lower[::-1]]),
                fill="toself",
                fillcolor=color,
                line=dict(color="rgba(0,0,0,0)"),
                opacity=0.18,
                name=f"Class {cls} range",
                showlegend=show_range_legend,
                # f-string: interpolates the class and collapses the doubled
                # braces into Plotly's %{x} / %{y} placeholders.
                hovertemplate=f"Class {cls}<br>Zone: %{{x:.1f}}<br>Min: %{{y:.3f}}<extra></extra>",
                customdata=[f"Class {cls}"] * len(x_envelope),
            ))

    # ── Mean class spectra (full spectrum, solid) ──────────────────────────
    for cls, color in zip(classes, palette):
        mask = (y_labels == cls).values
        parts = []
        for zone_name, _, _ in cut_rows:
            zone_df = zones_natural.get(zone_name)
            if zone_df is None or zone_df.empty:
                continue
            zone_mean = zone_df[mask].mean(axis=0)
            zone_mean.index = pd.to_numeric(zone_mean.index.astype(str), errors="coerce")
            parts.append(zone_mean)

        if not parts:
            continue

        full_mean = pd.concat(parts).sort_index().dropna()
        fig.add_trace(go.Scatter(
            x=full_mean.index.to_numpy(dtype=float),
            y=full_mean.to_numpy(dtype=float),
            mode="lines",
            line=dict(color=color, width=theme.class_line_width),
            name=f"Class {cls} mean",
        ))

    # ── Per-zone threshold spectra (dashed, LRC-colored) ──────────────────
    for _, row in top_per_zone.iterrows():
        zone_name = str(row["Zone"])
        threshold_score = float(row["Threshold_Natural"])
        lrc_score = float(row["Local_Reaching_Centrality"])

        threshold_spectrum = reconstruct_threshold_to_spectrum(
            threshold_value=threshold_score,
            zone_name=zone_name,
            pca_info_dict=pca_info_natural,
        )
        threshold_spectrum.index = pd.to_numeric(
            threshold_spectrum.index.astype(str), errors="coerce"
        )
        threshold_spectrum = threshold_spectrum.dropna().sort_index()

        t_color = _lrc_color(lrc_score)
        fig.add_trace(go.Scatter(
            x=threshold_spectrum.index.to_numpy(dtype=float),
            y=threshold_spectrum.to_numpy(dtype=float),
            mode="lines",
            line=dict(
                color=t_color,
                width=theme.threshold_line_width,
                dash=theme.threshold_line_dash,
            ),
            name=f"Threshold: {zone_name} (LRC {lrc_score:.3f})",
        ))

    # ── Zone boundary vlines ───────────────────────────────────────────────
    boundaries = sorted({start for _, start, _ in cut_rows} | {end for _, _, end in cut_rows})
    for b in boundaries:
        fig.add_vline(x=b, line=dict(
            color=theme.zone_boundary_color,
            width=theme.zone_boundary_width,
            dash=theme.zone_boundary_dash,
        ))

    fig.update_layout(
        **theme.plotly_layout(
            title=title or "All-Zone Threshold Overlay",
            xaxis_title="Spectral variables",
            yaxis_title="Intensity",
            legend=dict(orientation="h", y=-0.28, x=0.5, xanchor="center"),
            margin=dict(t=80, r=40, b=150, l=80),
        )
    )

    _write_figure(fig, output_path, width, height)
    fig.show()

    if return_df:
        return top_per_zone


# ── 5. Faithfulness Curve ─────────────────────────────────────────────────────
def plot_faithfulness_curve(
    faithfulness_result: Dict,
    output_path: Optional[Union[str, Path]],
    *,
    title: Optional[str] = None,
    theme: Optional[SMXTheme] = None,
    width: int = 1100,
    height: int = 560,
    show_percentile: bool = False,
    show_faithfulness_level: bool = True,
    show_summary: bool = True,
    show_level_intervals: bool = True,
    return_df: bool = False,
) -> Union[None, pd.DataFrame]:
    """Plot the progressive masking faithfulness curve and its AUC.

    The curve is drawn from a synthetic ``k=0`` / ``score=0.0`` origin row
    followed by ``curve_df`` sorted by ``k``, filled down to zero. Each real
    point is overlaid as a marker colored by its score and, when a
    ``masked_zone`` column exists, labelled with it. Optional annotation
    boxes show the faithfulness level, a summary (AUC, normalized AUC,
    percentile, metric) and the level-interval legend.

    The figure is always displayed; set *return_df=True* to return the
    masking-curve DataFrame.

    Parameters
    ----------
    faithfulness_result : dict
        Output of :meth:`smx.pipeline.SMX.evaluate_faithfulness`. Must contain
        ``curve_df`` (with ``k`` and ``score`` columns); ``auc``,
        ``auc_normalized``, ``level``, ``null_percentile``, ``metric`` and
        ``n_masked_zones`` are read when present.
    output_path : str or Path, optional
        Destination file. If ``None``, no file is written.
    title : str, optional
        Figure title.
    theme : SMXTheme, optional
        Visual theme.
    width : int, default 1100
        Figure width (static export).
    height : int, default 560
        Figure height (static export).
    show_percentile : bool, default False
        Whether to include the random-baseline percentile in the summary box.
        The value remains available in ``faithfulness_result`` either way.
    show_faithfulness_level : bool, default True
        Whether to show the "Faithfulness Level" annotation box. The box is
        skipped when the result carries no ``level``.
    show_summary : bool, default True
        Whether to show the summary box with AUC and metric info. The box is
        skipped when none of its fields are available.
    show_level_intervals : bool, default True
        Whether to show the "Level intervals" legend box.
    return_df : bool, default False
        If ``True``, return the masking-curve DataFrame.

    Returns
    -------
    pd.DataFrame or None
        Masking-curve DataFrame (a copy sorted by ``k``) when *return_df* is
        True.

    Raises
    ------
    TypeError
        If *faithfulness_result* is not a dict.
    ValueError
        If ``curve_df`` is missing, is not a non-empty DataFrame, or lacks
        the ``k`` / ``score`` columns.
    """
    go = _require_plotly()
    from plotly.colors import sample_colorscale

    theme = theme or DEFAULT_THEME

    if not isinstance(faithfulness_result, dict):
        raise TypeError("faithfulness_result must be a dictionary.")
    if "curve_df" not in faithfulness_result:
        raise ValueError("faithfulness_result must contain 'curve_df'.")

    curve_df = faithfulness_result["curve_df"]
    if not isinstance(curve_df, pd.DataFrame) or curve_df.empty:
        raise ValueError("faithfulness_result['curve_df'] must be a non-empty DataFrame.")

    required = {"k", "score"}
    missing = required.difference(curve_df.columns)
    if missing:
        raise ValueError(
            "faithfulness_result['curve_df'] is missing required columns: "
            + ", ".join(sorted(missing))
        )

    curve_df = curve_df.copy().sort_values("k").reset_index(drop=True)
    # Prepend an origin point so the filled curve starts at (0, 0).
    curve_with_zero = pd.concat([
        pd.DataFrame([{
            "k": 0,
            "masked_zone": "none",
            "masked_zones": tuple(),
            "score": 0.0,
        }]),
        curve_df,
    ], ignore_index=True)

    auc = faithfulness_result.get("auc")
    auc_normalized = faithfulness_result.get("auc_normalized")
    level = faithfulness_result.get("level")
    percentile = faithfulness_result.get("null_percentile")
    metric = faithfulness_result.get("metric")
    n_masked_zones = faithfulness_result.get("n_masked_zones")
    if n_masked_zones is None:
        # Key absent or explicitly None: fall back to the number of points so
        # the x-range computation below never calls float(None).
        n_masked_zones = len(curve_df)

    score_min = float(curve_df["score"].min())
    score_max = float(curve_df["score"].max())
    _colorscale = theme.colorscale
    _blended_colorscale = build_blended_colorscale(_colorscale, theme.zone_opacity)

    def _score_color(score: float) -> str:
        norm = (score - score_min) / max(score_max - score_min, 1e-9)
        return sample_colorscale(_colorscale, [norm])[0]

    def _score_color_blended(score: float) -> str:
        norm = (score - score_min) / max(score_max - score_min, 1e-9)
        return sample_colorscale(_blended_colorscale, [norm])[0]

    line_color = _score_color(score_max)
    fill_color = _score_color_blended(score_max)

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=curve_with_zero["k"],
        y=curve_with_zero["score"],
        mode="lines",
        name="Faithfulness curve",
        line=dict(color=line_color, width=theme.threshold_line_width + 1),
        fill="tozeroy",
        # Turn "rgb(r, g, b)" into "rgba(r, g, b, opacity)" for the fill.
        fillcolor=fill_color.replace("rgb(", "rgba(").replace(")", f", {theme.zone_opacity})"),
        hoverinfo="skip",
    ))

    has_zone_labels = "masked_zone" in curve_df.columns
    fig.add_trace(go.Scatter(
        x=curve_df["k"],
        y=curve_df["score"],
        mode="markers+text",
        name="Prediction shift",
        marker=dict(
            size=10,
            color=curve_df["score"],
            coloraxis="coloraxis",
            line=dict(color="white", width=1.5),
        ),
        text=curve_df["masked_zone"] if has_zone_labels else None,
        textposition="top center",
        textfont=dict(size=theme.annotation_font_size, family=theme.font_family),
        customdata=np.stack([
            curve_df["masked_zone"].astype(str).to_numpy()
            if has_zone_labels
            else np.repeat("", len(curve_df)),
        ], axis=-1),
        hovertemplate=(
            "k masked zones: %{x}<br>"
            "Prediction shift: %{y:.6f}<br>"
            "Latest masked zone: %{customdata[0]}<extra></extra>"
        ),
    ))

    # Light vertical guide at every masking step.
    for k in curve_df["k"].tolist():
        fig.add_vline(
            x=float(k),
            line=dict(
                color=theme.zone_boundary_color,
                width=theme.zone_boundary_width,
                dash=theme.zone_boundary_dash,
            ),
            layer="below",
        )

    # Skip the level box when no level is available instead of printing "None".
    if show_faithfulness_level and level is not None:
        fig.add_annotation(
            x=0.15,
            y=0.85,
            xref="paper",
            yref="paper",
            xanchor="right",
            yanchor="top",
            align="left",
            showarrow=False,
            bordercolor="rgba(140,140,140,0.35)",
            borderwidth=1,
            borderpad=8,
            bgcolor="rgba(255,255,255,0.88)",
            font=dict(size=17, family=theme.font_family),
            text=f"Faithfulness Level: <br><b>{level}</b>",
        )

    if show_summary:
        summary_lines = []
        if auc is not None:
            summary_lines.append(f"AUC: {float(auc):.4f}")
        if auc_normalized is not None:
            summary_lines.append(f"Normalized AUC: {float(auc_normalized):.4f}")
        if show_percentile and percentile is not None:
            summary_lines.append(f"Percentile: {float(percentile):.1f}%")
        if metric is not None:
            summary_lines.append(f"Metric: {metric}")

        # Do not draw an empty bordered box when nothing can be reported.
        if summary_lines:
            fig.add_annotation(
                x=0.995,
                y=0.85,
                xref="paper",
                yref="paper",
                xanchor="right",
                yanchor="top",
                align="left",
                showarrow=False,
                bordercolor="rgba(140,140,140,0.35)",
                borderwidth=1,
                borderpad=8,
                bgcolor="rgba(255,255,255,0.88)",
                text="<br>".join(summary_lines),
            )

    if show_level_intervals:
        fig.add_annotation(
            x=0.995,
            y=0.15,
            xref="paper",
            yref="paper",
            xanchor="right",
            yanchor="top",
            align="left",
            showarrow=False,
            bordercolor="rgba(140,140,140,0.35)",
            borderwidth=1,
            borderpad=8,
            bgcolor="rgba(255,255,255,0.88)",
            text=(
                "<b>Level intervals</b><br>"
                "Low: percentile &lt; 60<br>"
                "Moderate: 60-79.9<br>"
                "High: 80-94.9<br>"
                "Very high: >= 95"
            ),
        )

    fig.update_layout(
        **theme.plotly_layout(
            title=title or "Faithfulness via Progressive Zone Masking",
            xaxis=dict(
                title="k masked zones (cumulative top-ranked masking)",
                tickmode="linear",
                tick0=0,
                dtick=1,
                range=[-0.2, max(float(n_masked_zones), float(curve_df["k"].max())) + 0.4],
            ),
            yaxis=dict(
                title="Prediction shift score",
                rangemode="tozero",
            ),
            coloraxis=dict(
                colorscale=_blended_colorscale,
                cmin=score_min,
                cmax=score_max,
                showscale=False,
            ),
            legend=dict(orientation="h", y=-0.25, x=0.5, xanchor="center"),
            margin=dict(t=90, r=40, b=100, l=80),
        )
    )

    _write_figure(fig, output_path, width, height)
    fig.show()

    if return_df:
        return curve_df