# Source code for smx.graph.builder

"""
PredicateGraphBuilder: construct a directed predicate graph.

Edges are added between consecutive predicates ordered by a ranking metric
inside each bag, and weights are accumulated across bags.  Bidirectional
edges are resolved by keeping only the higher-weight direction.
"""

from typing import Dict, List, Optional

import numpy as np
import networkx as nx
import pandas as pd

from smx.graph.interpretation import _extract_zone_from_predicate


class PredicateGraphBuilder:
    """Build a directed predicate graph from bags and metric rankings.

    Edge weights derive from the ranking metric of the *source* predicate
    in each bag.  When the same directed edge appears in multiple bags, the
    weights are accumulated (summed).

    Bidirectional edges (A→B and B→A) are resolved by:

    * keeping the edge with the higher accumulated weight;
    * breaking ties randomly.

    Tie-breaking draws from a private ``numpy.random.RandomState`` seeded
    with *random_state*, so building a graph does not reseed NumPy's global
    random generator.

    Parameters
    ----------
    random_state : int, default 42
        Seed for random tie-breaking of bidirectional edges.
    show_details : bool, default True
        Print details about identified and removed bidirectional edges.
    var_exp : bool, default False
        When ``True``, multiply edge weights by the PC1 explained variance
        of the source predicate's spectral zone.  Requires *pca_info_dict*.
    pca_info_dict : dict, optional
        ``{zone_name: {'variance_explained': float, ...}}`` as returned by
        :class:`smx.zones.aggregation.ZoneAggregator` (``pca_info_``
        attribute).  Required when ``var_exp=True``.

    Raises
    ------
    ValueError
        If ``var_exp`` is true and *pca_info_dict* is ``None``.
    """

    def __init__(
        self,
        random_state: int = 42,
        show_details: bool = True,
        var_exp: bool = False,
        pca_info_dict: Optional[Dict] = None,
    ) -> None:
        if var_exp and pca_info_dict is None:
            raise ValueError("pca_info_dict is required when var_exp=True.")
        self.random_state = random_state
        self.show_details = show_details
        self.var_exp = var_exp
        self.pca_info_dict = pca_info_dict

    def build(
        self,
        bags_result: Dict[str, Dict[str, pd.DataFrame]],
        predicate_ranking_dict: Dict[str, pd.DataFrame],
        metric_column: str = "Covariance",
    ) -> nx.DiGraph:
        """Build and return the directed predicate graph.

        For every bag that has a non-empty ranking, the ranked predicates
        present in the bag are chained in ranking order
        (``p0 -> p1 -> ... -> pk``) and the last predicate is connected to
        a terminal class node.  Bags without a ranking, with an empty
        ranking, or with no ranked predicate in the bag are skipped.

        Parameters
        ----------
        bags_result : dict
            Bags as returned by
            :class:`smx.predicates.bagging.PredicateBagger`.
        predicate_ranking_dict : dict
            ``{bag_name: DataFrame(['Predicate', metric_column])}`` as
            returned by a
            :class:`smx.predicates.metrics.BasePredicateMetric` subclass.
        metric_column : str, default ``'Covariance'``
            Name of the metric column in *predicate_ranking_dict*.

        Returns
        -------
        nx.DiGraph
            Directed graph with ``'weight'`` edge attributes.
        """
        # A private RandomState seeded this way yields the same draws as the
        # global stream after ``np.random.seed(seed)``, but leaves the
        # process-wide generator untouched.
        rng = np.random.RandomState(self.random_state)

        DG: nx.DiGraph = nx.DiGraph()
        DG.add_node("Class_A", node_type="terminal", class_label="A")
        DG.add_node("Class_B", node_type="terminal", class_label="B")

        # ── Phase 1: accumulate edge weights ─────────────────────────────
        for bag_name, bag_predicates_dict in bags_result.items():
            ranking_df = predicate_ranking_dict.get(bag_name)
            if ranking_df is None or ranking_df.empty:
                continue

            # Keep the ranking order, restricted to predicates in this bag.
            ordered = [
                p
                for p in ranking_df["Predicate"].tolist()
                if p in bag_predicates_dict
            ]
            if not ordered:
                continue

            lookup: Dict[str, float] = dict(
                zip(ranking_df["Predicate"], ranking_df[metric_column])
            )

            for predicate in ordered:
                DG.add_node(predicate, node_type="predicate")

            # Consecutive predicates; the weight comes from the source.
            for src, dst in zip(ordered, ordered[1:]):
                self._accumulate(
                    DG, src, dst, self._edge_weight(src, lookup), bag_name
                )

            # Last predicate → terminal
            last = ordered[-1]
            terminal = self._terminal_node(DG, bag_predicates_dict[last])
            self._accumulate(
                DG, last, terminal, self._edge_weight(last, lookup), bag_name
            )

        # ── Phase 2: resolve bidirectional edges ─────────────────────────
        n_removed = self._resolve_bidirectional(DG, rng)

        rule = "=" * 70
        n_predicates = sum(
            1
            for _, attrs in DG.nodes(data=True)
            if attrs.get("node_type") == "predicate"
        )
        print(
            f"\n{rule}\n"
            f"CONSTRUCTED GRAPH SUMMARY\n"
            f"{rule}\n"
            f"Edges (after removing {n_removed} bidirectional): {DG.number_of_edges()}\n"
            f"Predicate nodes: "
            f"{n_predicates}\n"
            f"Metric: {metric_column}\n"
            f"Variance-exp weighting: {'ENABLED' if self.var_exp else 'DISABLED'}"
        )
        return DG

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _terminal_node(DG: nx.DiGraph, df_last: pd.DataFrame) -> str:
        """Return the terminal node name for a bag's last predicate.

        The terminal is ``Class_<majority>``, where *majority* is the most
        frequent value of the ``Class_Predicted`` column; ``"Class_A"`` is
        used when that column is absent.  A terminal that is not yet in the
        graph is registered with the same attributes as the two seeded
        terminals, so it is not left as an untyped node.
        """
        if "Class_Predicted" not in df_last.columns:
            return "Class_A"
        majority = df_last["Class_Predicted"].value_counts().idxmax()
        terminal = f"Class_{majority}"
        if terminal not in DG:
            DG.add_node(
                terminal, node_type="terminal", class_label=str(majority)
            )
        return terminal

    def _edge_weight(self, predicate: str, lookup: Dict[str, float]) -> float:
        """Return the weight contributed by *predicate* as an edge source.

        The base weight is the predicate's metric value from *lookup*
        (``0.0`` when absent).  With ``var_exp`` enabled it is multiplied
        by the ``'variance_explained'`` entry of the predicate's zone,
        when that zone is present in ``pca_info_dict``.
        """
        w = float(lookup.get(predicate, 0.0))
        if self.var_exp and self.pca_info_dict is not None:
            try:
                zone = _extract_zone_from_predicate(predicate)
                if zone in self.pca_info_dict:
                    w *= self.pca_info_dict[zone]["variance_explained"]
            except ValueError:
                # Best effort: on ValueError the weight stays unscaled.
                pass
        return w

    @staticmethod
    def _accumulate(
        DG: nx.DiGraph, src: str, dst: str, w: float, bag: str
    ) -> None:
        """Add *w* to edge ``src -> dst``, creating the edge if needed.

        The ``bag`` attribute records the bag that first created the edge
        and is not updated by later bags.
        """
        if DG.has_edge(src, dst):
            DG[src][dst]["weight"] += w
        else:
            DG.add_edge(src, dst, weight=w, bag=bag)

    def _resolve_bidirectional(
        self,
        DG: nx.DiGraph,
        rng: Optional[np.random.RandomState] = None,
    ) -> int:
        """Remove the weaker direction of every conflicting edge pair.

        Parameters
        ----------
        DG : nx.DiGraph
            Graph to modify in place.
        rng : numpy.random.RandomState, optional
            Generator used to break ties.  When omitted, a fresh one seeded
            with ``self.random_state`` is created.

        Returns
        -------
        int
            Number of edges removed.
        """
        if rng is None:
            rng = np.random.RandomState(self.random_state)

        # Snapshot the conflicting pairs and their weights before mutating.
        pairs = []
        processed = set()
        for u, v in list(DG.edges()):
            if DG.has_edge(v, u) and (v, u) not in processed:
                pairs.append(
                    (u, v, float(DG[u][v]["weight"]), float(DG[v][u]["weight"]))
                )
                processed.add((u, v))
                processed.add((v, u))

        print(f"\nTotal bidirectional pairs found: {len(pairs)}")

        n_removed = 0
        for u, v, w_fwd, w_rev in pairs:
            if not (DG.has_edge(u, v) and DG.has_edge(v, u)):
                continue  # already resolved

            tie = False
            if w_fwd > w_rev:
                drop_reverse = True
            elif w_rev > w_fwd:
                drop_reverse = False
            else:
                # Equal (or incomparable, e.g. NaN) weights: pick at random.
                tie = True
                drop_reverse = rng.rand() > 0.5

            if drop_reverse:
                src, dst, w_drop, w_keep = v, u, w_rev, w_fwd
            else:
                src, dst, w_drop, w_keep = u, v, w_fwd, w_rev

            DG.remove_edge(src, dst)
            if self.show_details:
                if tie:
                    print(f"Tie! Removed: {src} -> {dst}")
                else:
                    print(
                        f"Removed: {src} -> {dst} ({w_drop:.4f}) | "
                        f"Kept: {dst} -> {src} ({w_keep:.4f})"
                    )
            n_removed += 1

        return n_removed