- from __future__ import annotations
-
- import math
- from collections import Counter
- from typing import TYPE_CHECKING, Dict, List, Tuple
-
- from .base_metric import BaseMetric
-
- if TYPE_CHECKING:
- from data.kg_dataset import KGDataset
-
-
- def _entropy_from_counter(counter: Counter[int], n: int) -> float:
- """Shannon entropy ``H = -Σ pᵢ log pᵢ`` in *nats* for the given colour counts."""
- ent = 0.0
- for count in counter.values():
- if count: # avoid log(0)
- p = count / n
- ent -= p * math.log(p)
- return ent
-
-
- def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]:
- """Return (C_ratio, C_NWLEC) given *per-iteration* entropies."""
-
- if n <= 1:
- # Degenerate graph.
- return 0.0, 0.0
-
- H_plus_1 = len(entropies)
- # Eq. (1): simple ratio normalisation.
- c_ratio = sum(ent / math.log(n) for ent in entropies) / H_plus_1
-
- # Eq. (2): effective-colour NWLEC.
- k_terms = [(math.exp(ent) - 1) / (n - 1) for ent in entropies]
- c_nwlec = sum(k_terms) / H_plus_1
- return c_ratio, c_nwlec
-
-
class WLEC(BaseMetric):
    """Compute the Weisfeiler-Lehman Entropy Complexity (WLEC) for a knowledge graph.

    Runs 1-WL colour refinement on the relation-typed directed graph and
    records the Shannon entropy of the colour distribution at every depth.
    """

    def __init__(self, dataset: KGDataset) -> None:
        """Initialise via :class:`BaseMetric` with the dataset to analyse."""
        super().__init__(dataset)

    def compute(
        self,
        H: int = 3,
        *,
        return_full: bool = False,
        progress: bool = True,
    ) -> Tuple[float, float, float] | Tuple[float, List[float], float, float]:
        """Compute the *average* Weisfeiler-Lehman Entropy Complexity (WLEC).

        Parameters
        ----------
        H:
            Number of refinement iterations **after** the initial colouring (so the
            colours are updated **H** times and the entropy is measured **H+1** times).
        return_full:
            If *True*, additionally return the list of per-depth entropies.
        progress:
            Whether to print a mini progress bar (no external dependency).

        Returns
        -------
        ``(average_entropy, c_ratio, c_nwlec)`` or, when ``return_full`` is set,
        ``(average_entropy, entropies, c_ratio, c_nwlec)`` — the mean of the
        per-depth entropies, optionally the per-depth list itself, and the two
        normalised scores from :func:`_normalise_scores`.

        Raises
        ------
        ValueError
            If the dataset contains zero entities.
        """
        n = int(self.num_entities)
        if n == 0:
            msg = "Dataset appears to contain zero entities."
            raise ValueError(msg)

        # Colour assignment - kept as a plain *list* of ints for cheap hashing.
        # Every entity starts with colour 1.
        colour: List[int] = [1] * n

        entropies: List[float] = []

        # Optional poor-man's progress bar (single line, carriage-return driven).
        def _print_prog(i: int) -> None:
            if progress:
                print(f"\rIteration {i}/{H}", end="", flush=True)

        for h in range(H + 1):
            _print_prog(h)

            # ------------------------------------------------------------------
            # Step 1: measure entropy of the current colouring ------------------
            # ------------------------------------------------------------------
            freq = Counter(colour)
            entropies.append(_entropy_from_counter(freq, n))

            if h == H:
                break  # Requested depth reached - skip the (unused) final refinement.

            # ------------------------------------------------------------------
            # Step 2: create refined colours -----------------------------------
            # ------------------------------------------------------------------
            bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {}
            next_colour: List[int] = [0] * n

            for v in range(n):
                # Build the multiset T of (relation, direction, neighbour-colour)
                # features from both edge directions.
                # NOTE(review): assumes self.out_adj / self.in_adj map each entity
                # to (relation, neighbour) pairs — supplied by BaseMetric; confirm.
                T: List[Tuple[int, int, int]] = []
                for r, t in self.out_adj[v]:
                    T.append((r, 0, colour[t]))  # 0 = outgoing
                for r, h_ in self.in_adj[v]:
                    T.append((r, 1, colour[h_]))  # 1 = incoming
                T.sort()  # canonical ordering so equal multisets compare equal

                key = (colour[v], tuple(T))
                if key not in bucket:
                    bucket[key] = len(bucket) + 1  # new colour IDs start at 1
                next_colour[v] = bucket[key]

            colour = next_colour  # move to the next iteration

        # The loop already printed "Iteration H/H" on its final pass, so only
        # the trailing newline is needed here (the duplicate print was removed).
        if progress:
            print()

        average_entropy = sum(entropies) / len(entropies)

        c_ratio, c_nwlec = _normalise_scores(entropies, n)

        if return_full:
            return average_entropy, entropies, c_ratio, c_nwlec

        return average_entropy, c_ratio, c_nwlec
|