from __future__ import annotations import math from collections import Counter, defaultdict from typing import TYPE_CHECKING, Dict, List, Tuple import torch from .base_metric import BaseMetric if TYPE_CHECKING: from data.kg_dataset import KGDataset def _entropy_from_counter(counter: Counter[int], n: int) -> float: """Return Shannon entropy (nats) of a colour distribution of length *n*.""" ent = 0.0 for cnt in counter.values(): if cnt: p = cnt / n ent -= p * math.log(p) return ent def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]: """Return (C_ratio, C_NWLEC) ∈ [0,1]² from per-layer entropies.""" if n <= 1: return 0.0, 0.0 H_plus_1 = len(entropies) log_n = math.log(n) c_ratio = sum(ent / log_n for ent in entropies) / H_plus_1 c_nwlec = sum((math.exp(ent) - 1) / (n - 1) for ent in entropies) / H_plus_1 return c_ratio, c_nwlec def _relation_families( triples: torch.Tensor, num_relations: int, *, thresh: float = 0.9 ) -> torch.Tensor: """Return a LongTensor mapping each relation id → family id. Two relations *r, r_inv* are put into the same family when **at least `thresh` fraction** of the edges labelled *r* have a **single** reverse edge labelled *r_inv*. """ assert 0.0 < thresh <= 1.0 # Build map (h,t) -> list of relations on that edge direction. edge2rels: Dict[Tuple[int, int], List[int]] = defaultdict(list) for h, r, t in triples.tolist(): edge2rels[(h, t)].append(int(r)) # Count reciprocal co‑occurrences (r, r_rev). pair_counts: Dict[Tuple[int, int], int] = defaultdict(int) rel_totals: List[int] = [0] * num_relations for (h, t), rels in edge2rels.items(): rev_rels = edge2rels.get((t, h)) if not rev_rels: continue for r in rels: rel_totals[r] += 1 for r_rev in rev_rels: pair_counts[(r, r_rev)] += 1 # Proposed inverse mapping r -> r_inv. proposed: List[int | None] = [None] * num_relations for (r, r_rev), cnt in pair_counts.items(): if cnt == 0: continue if cnt / rel_totals[r] >= thresh: # majority of r's edges are reversed by r_rev proposed[r] = r_rev # Build family ids (transitive closure is overkill; data are simple). family = list(range(num_relations)) for r, rinv in enumerate(proposed): if rinv is not None: root = min(family[r], family[rinv]) family[r] = family[rinv] = root return torch.tensor(family, dtype=torch.long) def _conditional_relation_entropy( colours: List[int], triples: torch.Tensor, family_map: torch.Tensor, ) -> float: counts: Dict[Tuple[int, int], Dict[int, int]] = defaultdict(lambda: defaultdict(int)) m = int(triples.size(0)) fam = family_map # alias for speed for h, r, t in triples.tolist(): # key = (colours[h], colours[t]) # ordered key (direction‑aware) key = tuple(sorted((colours[h], colours[t]))) # unordered key (direction‑agnostic) counts[key][int(fam[r])] += 1 # ▼ use family id h_cond = 0.0 for rel_counts in counts.values(): total_s = sum(rel_counts.values()) P_s = total_s / m inv_total = 1.0 / total_s for cnt in rel_counts.values(): p_rs = cnt * inv_total h_cond -= P_s * p_rs * math.log(p_rs) return h_cond class WLCREC(BaseMetric): """Class to compute the Weisfeiler-Lehman Entropy Complexity (WLEC) for a knowledge graph.""" def __init__(self, dataset: KGDataset) -> None: super().__init__(dataset) def compute( self, H: int = 3, cond_h: int = 1, inv_thresh: float = 0.9, return_full: bool = False, progress: bool = True, ) -> float | Tuple[float, List[float]]: """Compute WL-entropy scores and the composite difficulty metric. Parameters ---------- dataset Any object exposing ``.triples``, ``.num_entities``, ``.num_relations``. H WL refinement depth (*≥0*). Entropy is recorded for **H+1** layers. return_full Also return the list of per-layer entropies. progress Print textual progress bar. Returns ------- avg_entropy, C_ratio, C_NWLEC, H_cond, D_ratio, D_NWLEC *Six* scalars (floats). If ``return_full=True`` an extra list of layer entropies is appended. """ # compute relation family map once family_map = _relation_families( self.dataset.triples, self.dataset.num_relations, thresh=inv_thresh ) # ─── WL iterations ──────────────────────────────────────────────────── n = self.dataset.num_entities colour: List[int] = [1] * n # colour 1 for everyone entropies: List[float] = [] colours_per_h: List[List[int]] = [] def _print_prog(i: int) -> None: if progress: print(f"\rWL iteration {i}/{H}", end="", flush=True) for h in range(H + 1): _print_prog(h) colours_per_h.append(colour.copy()) # entropy of current colouring freq = Counter(colour) entropies.append(_entropy_from_counter(freq, n)) if h == H: break # refine colours bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {} next_colour: List[int] = [0] * n for v in range(n): T: List[Tuple[int, int, int]] = [] for r, t in self.out_adj[v]: T.append((r, 0, colour[t])) # outgoing for r, h_ in self.in_adj[v]: T.append((r, 1, colour[h_])) # incoming T.sort() key = (colour[v], tuple(T)) if key not in bucket: bucket[key] = len(bucket) + 1 next_colour[v] = bucket[key] colour = next_colour _print_prog(H) if progress: print() # ─── Normalised diversity ──────────────────────────────────────────── avg_entropy = sum(entropies) / len(entropies) c_ratio, c_nwlec = _normalise_scores(entropies, n) # ─── Residual relation entropy ─────────────────────────────────────── h_cond = _conditional_relation_entropy( colours_per_h[cond_h], self.dataset.triples, family_map ) # ─── Composite difficulty - Eq. (1) ────────────────────────────────── d_ratio = c_ratio * h_cond d_nwlec = c_nwlec * h_cond result = (avg_entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec) if return_full: result += (entropies,) return result # -------------------------------------------------------------------- # Weisfeiler–Lehman colours (unchanged) # -------------------------------------------------------------------- def wl_colours(self, depth: int) -> List[List[int]]: triples = self.dataset.triples.tolist() out_adj, in_adj = defaultdict(list), defaultdict(list) for h, r, t in triples: out_adj[h].append((r, t)) in_adj[t].append((r, h)) n_ent = self.dataset.num_entities colours = [[0] * n_ent] # round-0 for h in range(1, depth + 1): prev, nxt, sig2c = colours[-1], [0] * n_ent, {} fresh = 0 for v in range(n_ent): neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [ ("↑", r, prev[u]) for r, u in in_adj.get(v, []) ] neigh.sort() sig = (prev[v], tuple(neigh)) if sig not in sig2c: sig2c[sig] = fresh fresh += 1 nxt[v] = sig2c[sig] colours.append(nxt) return colours