"""
Summary and diagnostic plots for SMX explanation results.
Functions
---------
plot_lrc_bar
Horizontal bar chart of LRC scores per zone.
plot_predicate_heatmap
Zone × predicate heatmap of LRC scores.
plot_zone_scores
Split-violin of PC1 scores per zone by class.
plot_all_thresholds_overlay
Full-spectrum overlay of top-predicate threshold per zone.
plot_faithfulness_curve
Progressive masking curve with shaded AUC and summary annotations.
"""
from __future__ import annotations
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Union
import numpy as np
import pandas as pd
from smx.plotting.theme import DEFAULT_THEME, SMXTheme, blend_with_white, build_blended_colorscale
# ── Shared helpers ─────────────────────────────────────────────────────────────
def _write_figure(fig, output_path: Optional[Union[str, Path]], width: int, height: int) -> None:
    """Write *fig* to *output_path*, choosing the exporter from the file suffix.

    Parameters
    ----------
    fig : plotly.graph_objects.Figure
        Figure to export (any object exposing ``write_html`` / ``write_image``).
    output_path : str or Path, optional
        Destination file. ``None`` disables writing entirely.
    width, height : int
        Pixel size passed to static image export (unused for HTML).

    Raises
    ------
    ImportError
        If static export fails because the kaleido engine is unavailable.
    ValueError
        If the suffix is not a supported format, or if static export fails
        for a reason unrelated to kaleido.
    """
    if output_path is None:
        return
    output_path = Path(output_path)
    suffix = output_path.suffix.lower()
    if suffix == ".html":
        fig.write_html(str(output_path))
    elif suffix in {".png", ".svg", ".pdf", ".jpg", ".jpeg", ".webp"}:
        try:
            fig.write_image(str(output_path), width=width, height=height)
        except ValueError as exc:
            # plotly is expected to report a missing kaleido engine with a
            # ValueError mentioning "kaleido". Only that case is translated, so
            # genuine argument errors are not misreported as a missing package.
            if "kaleido" not in str(exc).lower():
                raise
            raise ImportError(
                "Static image export requires kaleido. "
                "Install it with: pip install kaleido"
            ) from exc
    else:
        raise ValueError(
            f"Unsupported output format '{suffix}'. "
            "Use '.html' for interactive or "
            "'.png'/'.svg'/'.pdf'/'.jpg'/'.jpeg'/'.webp' for static image."
        )
def _require_plotly():
    """Import and return ``plotly.graph_objects``.

    Raises
    ------
    ImportError
        With an installation hint when plotly is not importable.
    """
    try:
        import plotly.graph_objects as graph_objects
    except ImportError as err:
        raise ImportError(
            "plotly is required for SMX plotting. "
            "Install it with: pip install plotly"
        ) from err
    return graph_objects
# ── 1. LRC Bar Chart ───────────────────────────────────────────────────────────
[docs]
def plot_lrc_bar(
    zone_ranking_df: pd.DataFrame,
    output_path: Optional[Union[str, Path]],
    *,
    title: Optional[str] = None,
    colorscale: Optional[str] = None,
    theme: Optional[SMXTheme] = None,
    width: int = 800,
    height: int = 500,
    return_df: bool = False,
) -> Union[None, pd.DataFrame]:
    """Horizontal bar chart of LRC scores per zone.

    Each bar represents a spectral zone and is colored according to the same
    LRC-score colorscale used in :func:`plot_zone_ranking_over_spectrum`,
    making the two plots directly comparable. Bar length is the zone's share
    of the total LRC score (in percent). The figure is always displayed;
    set *return_df=True* to return the normalized ranking DataFrame.

    Parameters
    ----------
    zone_ranking_df : pd.DataFrame
        LRC table (``Zone`` / ``Local_Reaching_Centrality`` columns) or a
        pre-normalized ``zone`` / ``score`` / ``rank`` DataFrame.
    output_path : str or Path, optional
        Destination file. If ``None``, no file is written.
    title : str, optional
        Figure title.
    colorscale : str, optional
        Plotly colorscale name. Defaults to ``theme.colorscale``.
    theme : SMXTheme, optional
        Visual theme. Defaults to :data:`~smx.plotting.theme.DEFAULT_THEME`.
    width : int, default 800
        Figure width in pixels (static export only).
    height : int, default 500
        Figure height in pixels (static export only).
    return_df : bool, default False
        If ``True``, return the normalized ranking DataFrame.

    Returns
    -------
    pd.DataFrame or None
        Normalized ``zone / score / rank`` DataFrame (plus a ``pct`` column)
        when *return_df* is True.
    """
    go = _require_plotly()
    from plotly.colors import sample_colorscale
    from smx.plotting.zones import _prepare_zone_ranking_df

    theme = theme or DEFAULT_THEME
    _colorscale = colorscale or theme.colorscale
    _opacity = theme.zone_opacity

    ranking_df = _prepare_zone_ranking_df(zone_ranking_df)
    # Ascending order: plotly draws the first category at the bottom of a
    # horizontal bar chart, so the highest score ends up on top.
    ranking_df = ranking_df.sort_values("score", ascending=True)
    score_total = float(ranking_df["score"].sum())
    # Guard the denominator so an all-zero ranking gives 0 % rather than NaN.
    ranking_df["pct"] = ranking_df["score"] / max(score_total, 1e-9) * 100

    score_min = float(ranking_df["score"].min())
    score_max = float(ranking_df["score"].max())

    def _norm(s: float) -> float:
        # Min-max normalise into [0, 1] for colorscale sampling.
        return (s - score_min) / max(score_max - score_min, 1e-9)

    colors = [
        blend_with_white(sample_colorscale(_colorscale, [_norm(s)])[0], _opacity)
        for s in ranking_df["score"]
    ]
    # f-string formatting also accepts non-string zone identifiers.
    labels = [f"#{int(r)} {z}" for r, z in zip(ranking_df["rank"], ranking_df["zone"])]

    fig = go.Figure(go.Bar(
        x=ranking_df["pct"],
        y=labels,
        orientation="h",
        marker=dict(color=colors, line=dict(color="#555555", width=1)),
        text=[f"{p:.1f}%" for p in ranking_df["pct"]],
        textposition="outside",
        hovertemplate="Zone: %{y}<br>Share: %{x:.2f}%<br>LRC: %{customdata:.4f}<extra></extra>",
        customdata=ranking_df["score"].tolist(),
    ))

    x_max = float(ranking_df["pct"].max())
    # Leave 25 % headroom for the outside text labels. Fall back to a unit
    # range when every share is zero (or the table is empty, giving NaN) so
    # the axis is not collapsed to [0, 0].
    x_upper = x_max * 1.25 if x_max > 0 else 1.0
    fig.update_layout(
        **theme.plotly_layout(
            title=title or "LRC Score by Spectral Zone",
            xaxis=dict(title="LRC Score (% of total)", range=[0, x_upper]),
            yaxis=dict(title="Zone"),
            margin=dict(t=80, r=100, b=60, l=180),
        )
    )
    _write_figure(fig, output_path, width, height)
    fig.show()
    if return_df:
        return ranking_df
    return None
# ── 2. Predicate Heatmap ───────────────────────────────────────────────────────
[docs]
def plot_predicate_heatmap(
    lrc_natural_df: pd.DataFrame,
    output_path: Optional[Union[str, Path]],
    *,
    title: Optional[str] = None,
    colorscale: Optional[str] = None,
    theme: Optional[SMXTheme] = None,
    width: int = 1000,
    height: int = 550,
    return_df: bool = False,
) -> Union[None, pd.DataFrame]:
    """Heatmap of LRC scores across zones and predicate thresholds.

    Rows are spectral zones (sorted by maximum LRC, highest at top). Columns
    are predicates within each zone, grouped by operator (``≤`` then ``>``)
    and sorted numerically by threshold rank within each group. Cell color
    encodes LRC score on the same colorscale as the bar chart and
    zone-ranking plot; cells with no predicate are drawn in light grey.
    The figure is always displayed; set *return_df=True* to return the pivot
    DataFrame.

    Parameters
    ----------
    lrc_natural_df : pd.DataFrame
        ``smx.lrc_natural_`` — must contain ``Zone``, ``Operator``,
        ``Threshold_Natural``, and ``Local_Reaching_Centrality`` columns.
    output_path : str or Path, optional
        Destination file. If ``None``, no file is written.
    title : str, optional
        Figure title.
    colorscale : str, optional
        Plotly colorscale name. Defaults to ``theme.colorscale``.
    theme : SMXTheme, optional
        Visual theme.
    width : int, default 1000
        Figure width (static export).
    height : int, default 550
        Figure height (static export).
    return_df : bool, default False
        If ``True``, return the pivot DataFrame (zones × predicate labels).

    Returns
    -------
    pd.DataFrame or None
        Pivot DataFrame (zones × predicate labels → LRC score) when
        *return_df* is True.
    """
    go = _require_plotly()
    theme = theme or DEFAULT_THEME
    _colorscale = colorscale or theme.colorscale
    _blended = build_blended_colorscale(_colorscale, theme.zone_opacity)

    df = lrc_natural_df[lrc_natural_df["Zone"].notna()].copy()
    df = df.sort_values(["Zone", "Operator", "Threshold_Natural"])
    # Rank thresholds 1..n within each (zone, operator) pair, in ascending
    # threshold order (the frame was just sorted that way).
    df["thresh_rank"] = df.groupby(["Zone", "Operator"]).cumcount() + 1
    op_symbol = {"<=": "≤", ">": ">"}
    # Operators missing from ``op_symbol`` map to NaN labels, which the
    # pivot's grouping drops.
    df["predicate_label"] = df["Operator"].map(op_symbol) + " T" + df["thresh_rank"].astype(str)

    pivot = df.pivot_table(
        index="Zone",
        columns="predicate_label",
        values="Local_Reaching_Centrality",
        aggfunc="max",
    )
    # Ascending so the zone with the highest LRC is drawn at the top row.
    zone_order = (
        df.groupby("Zone")["Local_Reaching_Centrality"]
        .max()
        .sort_values(ascending=True)
        .index.tolist()
    )
    pivot = pivot.reindex(zone_order)

    def _rank_of(label: str) -> int:
        # Labels are built above as "<op> T<rank>". Sorting on the integer
        # rank keeps "≤ T10" after "≤ T9" instead of after "≤ T1".
        return int(label.rsplit("T", 1)[1])

    le_cols = sorted((c for c in pivot.columns if c.startswith("≤")), key=_rank_of)
    gt_cols = sorted((c for c in pivot.columns if c.startswith(">")), key=_rank_of)
    pivot = pivot[le_cols + gt_cols]

    text_vals = [
        [f"{v:.3f}" if not np.isnan(v) else "—" for v in row]
        for row in pivot.values
    ]

    score_min = float(df["Local_Reaching_Centrality"].min())
    score_max = float(df["Local_Reaching_Centrality"].max())
    # Sentinel value for NaN cells: a quarter of the data span below score_min,
    # so they map to a dedicated "no data" color at the bottom of the extended
    # colorscale.
    # NOTE(review): when score_min == score_max the sentinel equals score_min,
    # so empty cells are not visually distinguished from data cells.
    _SENTINEL = score_min - (score_max - score_min) * 0.25
    _NO_DATA_COLOR = "rgb(220,220,220)"
    z_filled = np.where(np.isnan(pivot.values), _SENTINEL, pivot.values)

    # Prepend a "no data" segment [0, no_data_frac) → _NO_DATA_COLOR; the
    # blended data colorscale is rescaled to occupy [no_data_frac, 1].
    _data_range = score_max - _SENTINEL
    _no_data_frac = (score_min - _SENTINEL) / max(_data_range, 1e-9)
    _extended_cs = [[0.0, _NO_DATA_COLOR], [_no_data_frac, _NO_DATA_COLOR]] + [
        [_no_data_frac + (1 - _no_data_frac) * pos, color]
        for pos, color in _blended
    ]

    fig = go.Figure(go.Heatmap(
        z=z_filled.tolist(),
        x=pivot.columns.tolist(),
        y=pivot.index.tolist(),
        colorscale=_extended_cs,
        zmin=_SENTINEL,
        zmax=score_max,
        colorbar=dict(
            title=dict(text="LRC score", side="right"),
            thickness=theme.colorbar_thickness,
            len=theme.colorbar_len,
            tickmode="array",
            tickvals=[score_min, score_max],
            ticktext=[f"{score_min:.3f} (min)", f"{score_max:.3f} (max)"],
            tickfont=dict(size=10),
        ),
        text=text_vals,
        texttemplate="%{text}",
        textfont=dict(size=9, family=theme.font_family),
        hovertemplate="Zone: %{y}<br>Predicate: %{x}<br>LRC: %{text}<extra></extra>",
        hoverongaps=False,
        xgap=2,
        ygap=2,
    ))

    # Vertical separator and headers for the ≤ and > column groups.
    if le_cols and gt_cols:
        fig.add_vline(
            x=len(le_cols) - 0.5,
            line=dict(color="white", width=3),
        )
        fig.add_annotation(
            x=(len(le_cols) - 1) / 2,
            y=1.1,
            xref="x",
            yref="paper",
            text="Operator ≤",
            showarrow=False,
            font=dict(size=theme.annotation_font_size, family=theme.font_family),
        )
        fig.add_annotation(
            x=len(le_cols) + (len(gt_cols) - 1) / 2,
            y=1.1,
            xref="x",
            yref="paper",
            text="Operator >",
            showarrow=False,
            font=dict(size=theme.annotation_font_size, family=theme.font_family),
        )

    fig.update_layout(
        **theme.plotly_layout(
            title=title or "Predicate LRC Heatmap",
            xaxis=dict(title="Predicate (operator · threshold rank)", tickangle=-30),
            yaxis=dict(title="Zone"),
            margin=dict(t=100, r=120, b=100, l=160),
            plot_bgcolor="#d8d8d8",
        )
    )
    _write_figure(fig, output_path, width, height)
    fig.show()
    if return_df:
        return pivot
    return None
# ── 3. Zone PC1 Score Violin ───────────────────────────────────────────────────
[docs]
def plot_zone_scores(
    zones: Union[pd.DataFrame, Dict[str, pd.DataFrame]],
    y_labels: pd.Series,
    output_path: Optional[Union[str, Path]],
    spectral_cuts: Optional[Iterable] = None,
    *,
    title: Optional[str] = None,
    class_colors: Optional[Dict[str, str]] = None,
    theme: Optional[SMXTheme] = None,
    width: int = 1200,
    height: int = 580,
    return_df: bool = False,
) -> Union[None, pd.DataFrame]:
    """Draw PC1 scores of every spectral zone as violins, one per class.

    Zones are summarised with ``ZoneAggregator(method="pca")``, fitted and
    applied to the zone dictionary. With exactly two classes, each zone gets
    a mirrored split violin (first class on the negative side, second on the
    positive side). With any other number of classes, full violins are drawn
    for every class. The figure is always shown; pass *return_df=True* to
    also get the zone-score table back.

    Parameters
    ----------
    zones : pd.DataFrame or dict[str, pd.DataFrame]
        Either the full calibration DataFrame (requires ``spectral_cuts``) or
        a pre-extracted zone dict such as ``smx.zones_natural_``.
    y_labels : pd.Series
        Class labels, matched to the zone rows by position.
    output_path : str or Path, optional
        Destination file. If ``None``, no file is written.
    spectral_cuts : iterable, optional
        Zone boundary definitions. Required when *zones* is a DataFrame.
    title : str, optional
        Figure title.
    class_colors : dict[str, str], optional
        Per-class hex/CSS colors keyed by ``str(label)``; classes not listed
        fall back to ``theme.resolve_class_color``.
    theme : SMXTheme, optional
        Visual theme.
    width : int, default 1200
        Figure width (static export).
    height : int, default 580
        Figure height (static export).
    return_df : bool, default False
        If ``True``, return the zone PC1 scores DataFrame.

    Returns
    -------
    pd.DataFrame or None
        Zone PC1 score DataFrame (samples × zones) when *return_df* is True.

    Raises
    ------
    ValueError
        If *zones* is a DataFrame and *spectral_cuts* is ``None``.
    """
    go = _require_plotly()
    from smx.zones.aggregation import ZoneAggregator

    theme = theme or DEFAULT_THEME
    # Colours already handed out by the theme resolver (shared across classes).
    taken_colors: List[str] = []

    if not isinstance(zones, pd.DataFrame):
        zone_dict = zones
    elif spectral_cuts is None:
        raise ValueError("spectral_cuts is required when zones is a DataFrame.")
    else:
        from smx.zones.extraction import extract_spectral_zones
        zone_dict = extract_spectral_zones(zones, spectral_cuts)

    aggregator = ZoneAggregator(method="pca")
    aggregator.fit(zone_dict)
    zone_scores_df = aggregator.transform(zone_dict)
    zone_names = zone_scores_df.columns.tolist()

    labels_seen = list(y_labels.unique())
    if len(labels_seen) == 2:
        # Split mode: mirror the two classes around each zone's axis.
        violin_side = dict(zip(labels_seen, ("negative", "positive")))
    else:
        violin_side = {}

    palette = class_colors or {}
    fig = go.Figure()
    for label in labels_seen:
        in_class = (y_labels == label).values
        n_in_class = int(in_class.sum())
        label_key = str(label)
        trace_color = palette.get(label_key) or theme.resolve_class_color(label_key, taken_colors)
        for zone in zone_names:
            # Only the first zone's trace carries the legend entry per class.
            fig.add_trace(go.Violin(
                x=[zone] * n_in_class,
                y=zone_scores_df.loc[in_class, zone].values,
                name=f"Class {label}",
                legendgroup=label_key,
                showlegend=(zone == zone_names[0]),
                side=violin_side.get(label, "both"),
                line_color=trace_color,
                fillcolor=trace_color,
                opacity=0.85,
                box_visible=False,
                meanline_visible=True,
                points=False,
                width=0.6,
            ))

    layout_kwargs = theme.plotly_layout(
        title=title or "PC1 Scores by Spectral Zone and Class",
        xaxis=dict(title="Spectral Zone", tickangle=-30),
        yaxis=dict(title="PC 1 Score"),
        violingap=0.05,
        violingroupgap=0.0,
        legend=dict(orientation="h", y=-0.33, x=0.85, xanchor="center"),
        margin=dict(t=80, r=40, b=140, l=80),
    )
    fig.update_layout(**layout_kwargs)
    _write_figure(fig, output_path, width, height)
    fig.show()
    return zone_scores_df if return_df else None
# ── 4. All-Zone Threshold Overlay ──────────────────────────────────────────────
[docs]
def plot_all_thresholds_overlay(
    lrc_natural_df: pd.DataFrame,
    zones_natural: Dict[str, pd.DataFrame],
    pca_info_natural: Dict,
    y_labels: pd.Series,
    spectral_cuts: Iterable,
    output_path: Optional[Union[str, Path]],
    *,
    title: Optional[str] = None,
    class_colors: Optional[Dict[str, str]] = None,
    theme: Optional[SMXTheme] = None,
    width: int = 1200,
    height: int = 500,
    return_df: bool = False,
) -> Union[None, pd.DataFrame]:
    """Full-spectrum overlay of the top-ranked threshold per zone.

    Per-class min/max envelopes are drawn as shaded bands, and mean class
    spectra as solid lines, across all zones listed in *spectral_cuts*. The
    top-ranked predicate threshold for each zone (highest LRC) is
    reconstructed from PCA space and overlaid as a dashed line within its
    zone's x-range. Threshold line colors follow the LRC-score colorscale so
    the most influential zones stand out visually. Zone boundaries are
    marked with vertical lines. The figure is always displayed; set
    *return_df=True* to return the top-predicate-per-zone DataFrame.

    Parameters
    ----------
    lrc_natural_df : pd.DataFrame
        ``smx.lrc_natural_``.
    zones_natural : dict[str, pd.DataFrame]
        ``smx.zones_natural_``. Column labels are coerced to numbers for the
        x-axis; labels that do not parse are skipped.
    pca_info_natural : dict
        ``smx.pca_info_natural_``.
    y_labels : pd.Series
        Class labels aligned row-wise (by position) with the calibration data.
    spectral_cuts : iterable
        Zone boundary definitions: dicts with ``name``/``start``/``end``,
        ``(name, start, end)`` triples, or ``(start, end)`` pairs (named
        ``"start-end"``).
    output_path : str or Path, optional
        Destination file. If ``None``, no file is written.
    title : str, optional
        Figure title.
    class_colors : dict[str, str], optional
        Per-class hex/CSS colors keyed by ``str(label)``.
    theme : SMXTheme, optional
        Visual theme.
    width : int, default 1200
        Figure width (static export).
    height : int, default 500
        Figure height (static export).
    return_df : bool, default False
        If ``True``, return the top-predicate-per-zone DataFrame.

    Returns
    -------
    pd.DataFrame or None
        Top-predicate-per-zone DataFrame when *return_df* is True.
    """
    go = _require_plotly()
    from plotly.colors import sample_colorscale
    from smx.graph.interpretation import reconstruct_threshold_to_spectrum

    theme = theme or DEFAULT_THEME
    _used: List[str] = []

    # Highest-LRC predicate per zone: sort descending, keep first row per zone.
    top_per_zone = (
        lrc_natural_df[lrc_natural_df["Zone"].notna()]
        .sort_values("Local_Reaching_Centrality", ascending=False)
        .drop_duplicates(subset=["Zone"])
        .reset_index(drop=True)
    )
    score_min = float(top_per_zone["Local_Reaching_Centrality"].min())
    score_max = float(top_per_zone["Local_Reaching_Centrality"].max())

    def _lrc_color(score: float) -> str:
        # Min-max normalise the score, then sample the theme colorscale.
        norm = (score - score_min) / max(score_max - score_min, 1e-9)
        return sample_colorscale(theme.colorscale, [norm])[0]

    def _numeric_index(obj):
        # Copy *obj*, coerce its index labels to numbers, and drop labels
        # that do not parse (NaN x-values would break lines and fills).
        obj = obj.copy()
        obj.index = pd.to_numeric(obj.index.astype(str), errors="coerce")
        return obj[obj.index.notna()]

    # Parse spectral cuts into (zone name, start, end) rows.
    cut_rows = []
    for cut in spectral_cuts:
        if isinstance(cut, dict):
            cut_rows.append((str(cut["name"]), float(cut["start"]), float(cut["end"])))
        elif len(cut) == 3:
            cut_rows.append((str(cut[0]), float(cut[1]), float(cut[2])))
        else:
            cut_rows.append((f"{cut[0]}-{cut[1]}", float(cut[0]), float(cut[1])))

    def _zone_frames():
        # Yield the non-empty natural-space zone frames, in cut order.
        for zone_name, _, _ in cut_rows:
            zone_df = zones_natural.get(zone_name)
            if zone_df is not None and not zone_df.empty:
                yield zone_df

    # Resolve each class colour exactly once, so a class's band and mean line
    # always share a colour even if ``theme.resolve_class_color`` assigns
    # colours statefully through ``_used``.
    palette = class_colors or {}
    class_specs = [
        (cls, palette.get(str(cls)) or theme.resolve_class_color(str(cls), _used))
        for cls in y_labels.unique()
    ]

    fig = go.Figure()

    # ── Per-class min / max bands (shaded) ─────────────────────────────────
    for cls, color in class_specs:
        mask = (y_labels == cls).values
        envelope_parts = []
        for zone_df in _zone_frames():
            zone_cls = zone_df[mask]
            zone_max = zone_cls.max(axis=0)
            zone_min = zone_cls.min(axis=0)
            # Build from raw arrays so upper/lower rows stay paired even if
            # column labels repeat.
            stats = pd.DataFrame(
                {"upper": zone_max.to_numpy(), "lower": zone_min.to_numpy()},
                index=zone_max.index,
            )
            envelope_parts.append(_numeric_index(stats))
        if not envelope_parts:
            continue
        # Drop rows where either bound is missing so the upper and lower
        # edges of the polygon stay aligned point-for-point.
        envelope = pd.concat(envelope_parts).dropna().sort_index()
        if envelope.empty:
            continue
        x_envelope = envelope.index.to_numpy(dtype=float)
        y_upper = envelope["upper"].to_numpy(dtype=float)
        y_lower = envelope["lower"].to_numpy(dtype=float)
        # Closed polygon: along the upper edge, then back along the lower one.
        fig.add_trace(go.Scatter(
            x=np.concatenate([x_envelope, x_envelope[::-1]]),
            y=np.concatenate([y_upper, y_lower[::-1]]),
            fill="toself",
            fillcolor=color,
            line=dict(color="rgba(0,0,0,0)"),
            opacity=0.18,
            name=f"Class {cls} range",
            showlegend=True,
            hovertemplate=(
                f"Class {cls} range<br>"
                "Variable: %{x:.1f}<br>"
                "Intensity: %{y:.3f}<extra></extra>"
            ),
        ))

    # ── Mean class spectra (full spectrum, solid) ──────────────────────────
    for cls, color in class_specs:
        mask = (y_labels == cls).values
        parts = [_numeric_index(zone_df[mask].mean(axis=0)) for zone_df in _zone_frames()]
        if not parts:
            continue
        full_mean = pd.concat(parts).dropna().sort_index()
        fig.add_trace(go.Scatter(
            x=full_mean.index.to_numpy(dtype=float),
            y=full_mean.to_numpy(dtype=float),
            mode="lines",
            line=dict(color=color, width=theme.class_line_width),
            name=f"Class {cls} mean",
        ))

    # ── Per-zone threshold spectra (dashed, LRC-colored) ──────────────────
    for _, row in top_per_zone.iterrows():
        zone_name = str(row["Zone"])
        threshold_value = float(row["Threshold_Natural"])
        lrc_score = float(row["Local_Reaching_Centrality"])
        threshold_spectrum = reconstruct_threshold_to_spectrum(
            threshold_value=threshold_value,
            zone_name=zone_name,
            pca_info_dict=pca_info_natural,
        )
        threshold_spectrum = _numeric_index(threshold_spectrum).dropna().sort_index()
        fig.add_trace(go.Scatter(
            x=threshold_spectrum.index.to_numpy(dtype=float),
            y=threshold_spectrum.to_numpy(dtype=float),
            mode="lines",
            line=dict(
                color=_lrc_color(lrc_score),
                width=theme.threshold_line_width,
                dash=theme.threshold_line_dash,
            ),
            name=f"Threshold: {zone_name} (LRC {lrc_score:.3f})",
        ))

    # ── Zone boundary vlines (each distinct start/end once) ────────────────
    boundaries = sorted({start for _, start, _ in cut_rows} | {end for _, _, end in cut_rows})
    for b in boundaries:
        fig.add_vline(x=b, line=dict(
            color=theme.zone_boundary_color,
            width=theme.zone_boundary_width,
            dash=theme.zone_boundary_dash,
        ))

    fig.update_layout(
        **theme.plotly_layout(
            title=title or "All-Zone Threshold Overlay",
            xaxis_title="Spectral variables",
            yaxis_title="Intensity",
            legend=dict(orientation="h", y=-0.28, x=0.5, xanchor="center"),
            margin=dict(t=80, r=40, b=150, l=80),
        )
    )
    _write_figure(fig, output_path, width, height)
    fig.show()
    if return_df:
        return top_per_zone
    return None
# ── 5. Faithfulness Curve ─────────────────────────────────────────────────────
[docs]
def plot_faithfulness_curve(
    faithfulness_result: Dict,
    output_path: Optional[Union[str, Path]],
    *,
    title: Optional[str] = None,
    theme: Optional[SMXTheme] = None,
    width: int = 1100,
    height: int = 560,
    show_percentile: bool = False,
    show_faithfulness_level: bool = True,
    show_summary: bool = True,
    show_level_intervals: bool = True,
    return_df: bool = False,
) -> Union[None, pd.DataFrame]:
    """Plot the progressive masking faithfulness curve and its AUC.

    The curve (score versus number of masked zones *k*) is anchored at
    ``(0, 0)`` and its area is shaded. Markers are colored by score and
    labelled with the latest masked zone when ``masked_zone`` is available.
    The figure is always displayed; set *return_df=True* to return the
    masking-curve DataFrame.

    Parameters
    ----------
    faithfulness_result : dict
        Output of :meth:`smx.pipeline.SMX.evaluate_faithfulness`.
        Must contain ``curve_df`` (with ``k`` and ``score`` columns). It is
        expected to include ``auc``, ``level``, ``null_percentile``, and
        related summary fields; missing ones are simply not shown.
    output_path : str or Path, optional
        Destination file. If ``None``, no file is written.
    title : str, optional
        Figure title.
    theme : SMXTheme, optional
        Visual theme.
    width : int, default 1100
        Figure width (static export).
    height : int, default 560
        Figure height (static export).
    show_percentile : bool, default False
        Whether to include the random-baseline percentile in the summary box.
        The value remains available in ``faithfulness_result`` either way.
    show_faithfulness_level : bool, default True
        Whether to show the "Faithfulness Level" annotation box (only drawn
        when ``level`` is present).
    show_summary : bool, default True
        Whether to show the summary box with AUC and metric info (only drawn
        when at least one summary field is present).
    show_level_intervals : bool, default True
        Whether to show the "Level intervals" legend box.
    return_df : bool, default False
        If ``True``, return the masking-curve DataFrame.

    Returns
    -------
    pd.DataFrame or None
        Masking-curve DataFrame (sorted by ``k``) when *return_df* is True.

    Raises
    ------
    TypeError
        If *faithfulness_result* is not a dict.
    ValueError
        If ``curve_df`` is missing, empty, not a DataFrame, or lacks the
        ``k`` / ``score`` columns.
    """
    # Validate input first so malformed results fail fast, before plotly is
    # imported.
    if not isinstance(faithfulness_result, dict):
        raise TypeError("faithfulness_result must be a dictionary.")
    if "curve_df" not in faithfulness_result:
        raise ValueError("faithfulness_result must contain 'curve_df'.")
    curve_df = faithfulness_result["curve_df"]
    if not isinstance(curve_df, pd.DataFrame) or curve_df.empty:
        raise ValueError("faithfulness_result['curve_df'] must be a non-empty DataFrame.")
    missing = {"k", "score"}.difference(curve_df.columns)
    if missing:
        raise ValueError(
            "faithfulness_result['curve_df'] is missing required columns: "
            + ", ".join(sorted(missing))
        )

    go = _require_plotly()
    from plotly.colors import sample_colorscale

    theme = theme or DEFAULT_THEME

    curve_df = curve_df.copy().sort_values("k").reset_index(drop=True)
    # Anchor the curve at the origin (nothing masked → no prediction shift)
    # unless the caller already supplied a k = 0 point.
    if float(curve_df["k"].min()) > 0:
        curve_with_zero = pd.concat([
            pd.DataFrame([{
                "k": 0,
                "masked_zone": "none",
                "masked_zones": tuple(),
                "score": 0.0,
            }]),
            curve_df,
        ], ignore_index=True)
    else:
        curve_with_zero = curve_df

    auc = faithfulness_result.get("auc")
    auc_normalized = faithfulness_result.get("auc_normalized")
    level = faithfulness_result.get("level")
    percentile = faithfulness_result.get("null_percentile")
    metric = faithfulness_result.get("metric")
    # An explicit None also falls back to the curve length (float(None) would fail).
    n_masked_zones = faithfulness_result.get("n_masked_zones")
    if n_masked_zones is None:
        n_masked_zones = len(curve_df)

    score_min = float(curve_df["score"].min())
    score_max = float(curve_df["score"].max())
    _colorscale = theme.colorscale
    _blended_colorscale = build_blended_colorscale(_colorscale, theme.zone_opacity)

    def _score_color(score: float) -> str:
        norm = (score - score_min) / max(score_max - score_min, 1e-9)
        return sample_colorscale(_colorscale, [norm])[0]

    def _score_color_blended(score: float) -> str:
        norm = (score - score_min) / max(score_max - score_min, 1e-9)
        return sample_colorscale(_blended_colorscale, [norm])[0]

    line_color = _score_color(score_max)
    fill_color = _score_color_blended(score_max)

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=curve_with_zero["k"],
        y=curve_with_zero["score"],
        mode="lines",
        name="Faithfulness curve",
        line=dict(color=line_color, width=theme.threshold_line_width + 1),
        fill="tozeroy",
        # Assumes sample_colorscale returned an "rgb(...)" string; convert it
        # to "rgba(..., opacity)" for a translucent AUC fill.
        fillcolor=fill_color.replace("rgb(", "rgba(").replace(")", f", {theme.zone_opacity})"),
        hoverinfo="skip",
    ))

    has_zone = "masked_zone" in curve_df.columns
    fig.add_trace(go.Scatter(
        x=curve_df["k"],
        y=curve_df["score"],
        mode="markers+text",
        name="Prediction shift",
        marker=dict(
            size=10,
            color=curve_df["score"],
            coloraxis="coloraxis",
            line=dict(color="white", width=1.5),
        ),
        text=curve_df["masked_zone"] if has_zone else None,
        textposition="top center",
        textfont=dict(size=theme.annotation_font_size, family=theme.font_family),
        # One customdata column per point: the latest masked zone name.
        customdata=np.stack([
            curve_df["masked_zone"].astype(str).to_numpy()
            if has_zone
            else np.repeat("", len(curve_df)),
        ], axis=-1),
        hovertemplate=(
            "k masked zones: %{x}<br>"
            "Prediction shift: %{y:.6f}<br>"
            "Latest masked zone: %{customdata[0]}<extra></extra>"
        ),
    ))

    for k in curve_df["k"].tolist():
        fig.add_vline(
            x=float(k),
            line=dict(
                color=theme.zone_boundary_color,
                width=theme.zone_boundary_width,
                dash=theme.zone_boundary_dash,
            ),
            layer="below",
        )

    box_style = dict(
        xref="paper",
        yref="paper",
        xanchor="right",
        yanchor="top",
        align="left",
        showarrow=False,
        bordercolor="rgba(140,140,140,0.35)",
        borderwidth=1,
        borderpad=8,
        bgcolor="rgba(255,255,255,0.88)",
    )

    if show_faithfulness_level and level is not None:
        fig.add_annotation(
            x=0.15,
            y=0.85,
            font=dict(size=17, family=theme.font_family),
            text=f"Faithfulness Level: <br><b>{level}</b>",
            **box_style,
        )

    if show_summary:
        summary_lines = []
        if auc is not None:
            summary_lines.append(f"AUC: {float(auc):.4f}")
        if auc_normalized is not None:
            summary_lines.append(f"Normalized AUC: {float(auc_normalized):.4f}")
        if show_percentile and percentile is not None:
            summary_lines.append(f"Percentile: {float(percentile):.1f}%")
        if metric is not None:
            summary_lines.append(f"Metric: {metric}")
        # Skip the box entirely rather than drawing an empty frame.
        if summary_lines:
            fig.add_annotation(
                x=0.995,
                y=0.85,
                text="<br>".join(summary_lines),
                **box_style,
            )

    if show_level_intervals:
        fig.add_annotation(
            x=0.995,
            y=0.15,
            text=(
                "<b>Level intervals</b><br>"
                "Low: percentile < 60<br>"
                "Moderate: 60-79.9<br>"
                "High: 80-94.9<br>"
                "Very high: >= 95"
            ),
            **box_style,
        )

    fig.update_layout(
        **theme.plotly_layout(
            title=title or "Faithfulness via Progressive Zone Masking",
            xaxis=dict(
                title="k masked zones (cumulative top-ranked masking)",
                tickmode="linear",
                tick0=0,
                dtick=1,
                range=[-0.2, max(float(n_masked_zones), float(curve_df["k"].max())) + 0.4],
            ),
            yaxis=dict(
                title="Prediction shift score",
                rangemode="tozero",
            ),
            coloraxis=dict(
                colorscale=_blended_colorscale,
                cmin=score_min,
                cmax=score_max,
                showscale=False,
            ),
            legend=dict(orientation="h", y=-0.25, x=0.5, xanchor="center"),
            margin=dict(t=90, r=40, b=100, l=80),
        )
    )
    _write_figure(fig, output_path, width, height)
    fig.show()
    if return_df:
        return curve_df
    return None