from __future__ import annotations

import math
from collections import Counter
from typing import TYPE_CHECKING, Dict, List, Tuple

from .base_metric import BaseMetric

if TYPE_CHECKING:
    from data.kg_dataset import KGDataset


def _entropy_from_counter(counter: Counter[int], n: int) -> float:
    """Shannon entropy ``H = -Σ pᵢ log pᵢ`` in *nats* for the given colour counts."""
    ent = 0.0
    for count in counter.values():
        if count:  # avoid log(0)
            p = count / n
            ent -= p * math.log(p)
    return ent


def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]:
    """Return ``(C_ratio, C_NWLEC)`` given the *per-iteration* entropies.

    Eq. (1): ``C_ratio = (1 / (H + 1)) · Σ_h H_h / log n``
    Eq. (2): ``C_NWLEC = (1 / (H + 1)) · Σ_h (exp(H_h) - 1) / (n - 1)``
    """
    if n <= 1:  # Degenerate graph: entropy and both scores are zero.
        return 0.0, 0.0
    H_plus_1 = len(entropies)
    # Eq. (1): simple ratio normalisation against the maximum entropy log(n).
    c_ratio = sum(ent / math.log(n) for ent in entropies) / H_plus_1
    # Eq. (2): effective-colour NWLEC - exp(H_h) is the effective number of
    # colours, rescaled so that 1 colour maps to 0 and n colours map to 1.
    k_terms = [(math.exp(ent) - 1) / (n - 1) for ent in entropies]
    c_nwlec = sum(k_terms) / H_plus_1
    return c_ratio, c_nwlec


class WLEC(BaseMetric):
    """Weisfeiler-Lehman Entropy Complexity (WLEC) of a knowledge graph.

    Parameters
    ----------
    dataset:
        Any :class:`KGDataset`-like object exposing ``triples``
        (``LongTensor``), ``num_entities`` and ``num_relations``.
    """

    def __init__(self, dataset: KGDataset) -> None:
        super().__init__(dataset)

    def compute(
        self,
        H: int = 3,
        *,
        return_full: bool = False,
        progress: bool = True,
    ) -> Tuple[float, float, float] | Tuple[float, List[float], float, float]:
        """Compute the *average* Weisfeiler-Lehman Entropy Complexity (WLEC).

        Parameters
        ----------
        H:
            Number of refinement iterations **after** the initial colouring
            (so the colours are updated **H** times and the entropy is
            measured **H + 1** times).
        return_full:
            If *True*, additionally return the list of per-depth entropies.
        progress:
            Whether to print a mini progress bar (no external dependency).

        Returns
        -------
        ``(average_entropy, c_ratio, c_nwlec)``, or
        ``(average_entropy, entropies, c_ratio, c_nwlec)`` when
        ``return_full`` is *True*.
        """
        n = int(self.num_entities)
        if n == 0:
            msg = "Dataset appears to contain zero entities."
            raise ValueError(msg)

        # Colour assignment - kept as a *list* of ints for cheap hashing.
        # Start with colour 1 for every entity (colour ID 0 is never used).
        colour: List[int] = [1] * n
        entropies: List[float] = []

        # Optional poor-man's progress bar.
        def _print_prog(i: int) -> None:
            if progress:
                print(f"\rIteration {i}/{H}", end="", flush=True)

        for h in range(H + 1):
            _print_prog(h)

            # --------------------------------------------------------------
            # Step 1: measure entropy of the current colouring.
            # --------------------------------------------------------------
            freq = Counter(colour)
            ent = _entropy_from_counter(freq, n)
            entropies.append(ent)

            if h == H:
                break  # We have reached the requested depth - stop here.

            # --------------------------------------------------------------
            # Step 2: create refined colours.
            # --------------------------------------------------------------
            bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {}
            next_colour: List[int] = [0] * n
            for v in range(n):
                # Build the multiset T containing outgoing and incoming
                # features; ``out_adj``/``in_adj`` are the (relation,
                # neighbour) adjacency lists provided by ``BaseMetric``.
                T: List[Tuple[int, int, int]] = []
                for r, t in self.out_adj[v]:
                    T.append((r, 0, colour[t]))  # 0 = outgoing (↓)
                for r, h_ in self.in_adj[v]:
                    T.append((r, 1, colour[h_]))  # 1 = incoming (↑)
                T.sort()  # canonical ordering

                key = (colour[v], tuple(T))
                if key not in bucket:
                    bucket[key] = len(bucket) + 1  # new colour ID (start at 1)
                next_colour[v] = bucket[key]

            colour = next_colour  # move to the next iteration

        if progress:
            print()  # newline after the progress bar

        average_entropy = sum(entropies) / len(entropies)
        c_ratio, c_nwlec = _normalise_scores(entropies, n)

        if return_full:
            return average_entropy, entropies, c_ratio, c_nwlec
        return average_entropy, c_ratio, c_nwlec
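
# --------------------------------------------------------------------------
# Usage sketch (illustrative, not executed). The loader call below is
# hypothetical; the real requirement is only what the docstrings state: a
# ``KGDataset``-like object exposing ``triples``, ``num_entities`` and
# ``num_relations``, from which ``BaseMetric`` builds the adjacency lists.
#
#     from data.kg_dataset import KGDataset
#
#     dataset = KGDataset("path/to/train.tsv")  # hypothetical constructor
#     metric = WLEC(dataset)
#     avg_ent, c_ratio, c_nwlec = metric.compute(H=3, progress=False)
#     print(f"avg H={avg_ent:.4f}  C_ratio={c_ratio:.4f}  C_NWLEC={c_nwlec:.4f}")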
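
# --------------------------------------------------------------------------
# Minimal self-check of the helper functions above - a sketch, not part of
# the metric itself. It uses only module-level code, so it makes no
# assumptions about ``KGDataset``; run it with ``python -m`` because of the
# relative import at the top of this file.
if __name__ == "__main__":
    # n = 4 entities split into colour classes of sizes {2, 1, 1}:
    #   H       = 1.5 · ln 2 ≈ 1.0397 nats
    #   C_ratio = H / ln 4 = 0.75
    #   C_NWLEC = (e^H - 1) / (4 - 1) = (2^1.5 - 1) / 3 ≈ 0.6095
    demo_ent = _entropy_from_counter(Counter({1: 2, 2: 1, 3: 1}), 4)
    demo_ratio, demo_nwlec = _normalise_scores([demo_ent], 4)
    assert abs(demo_ent - 1.5 * math.log(2)) < 1e-12
    assert abs(demo_ratio - 0.75) < 1e-12
    assert abs(demo_nwlec - (2 ** 1.5 - 1) / 3) < 1e-12
    print(f"entropy={demo_ent:.4f}  C_ratio={demo_ratio:.4f}  C_NWLEC={demo_nwlec:.4f}")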