You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

wlcrec.py 8.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. from __future__ import annotations
  2. import math
  3. from collections import Counter, defaultdict
  4. from typing import TYPE_CHECKING, Dict, List, Tuple
  5. import torch
  6. from .base_metric import BaseMetric
  7. if TYPE_CHECKING:
  8. from data.kg_dataset import KGDataset
  9. def _entropy_from_counter(counter: Counter[int], n: int) -> float:
  10. """Return Shannon entropy (nats) of a colour distribution of length *n*."""
  11. ent = 0.0
  12. for cnt in counter.values():
  13. if cnt:
  14. p = cnt / n
  15. ent -= p * math.log(p)
  16. return ent
  17. def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]:
  18. """Return (C_ratio, C_NWLEC) ∈ [0,1]² from per-layer entropies."""
  19. if n <= 1:
  20. return 0.0, 0.0
  21. H_plus_1 = len(entropies)
  22. log_n = math.log(n)
  23. c_ratio = sum(ent / log_n for ent in entropies) / H_plus_1
  24. c_nwlec = sum((math.exp(ent) - 1) / (n - 1) for ent in entropies) / H_plus_1
  25. return c_ratio, c_nwlec
  26. def _relation_families(
  27. triples: torch.Tensor, num_relations: int, *, thresh: float = 0.9
  28. ) -> torch.Tensor:
  29. """Return a LongTensor mapping each relation id → family id.
  30. Two relations *r, r_inv* are put into the same family when **at least
  31. `thresh` fraction** of the edges labelled *r* have a **single** reverse edge
  32. labelled *r_inv*.
  33. """
  34. assert 0.0 < thresh <= 1.0
  35. # Build map (h,t) -> list of relations on that edge direction.
  36. edge2rels: Dict[Tuple[int, int], List[int]] = defaultdict(list)
  37. for h, r, t in triples.tolist():
  38. edge2rels[(h, t)].append(int(r))
  39. # Count reciprocal co‑occurrences (r, r_rev).
  40. pair_counts: Dict[Tuple[int, int], int] = defaultdict(int)
  41. rel_totals: List[int] = [0] * num_relations
  42. for (h, t), rels in edge2rels.items():
  43. rev_rels = edge2rels.get((t, h))
  44. if not rev_rels:
  45. continue
  46. for r in rels:
  47. rel_totals[r] += 1
  48. for r_rev in rev_rels:
  49. pair_counts[(r, r_rev)] += 1
  50. # Proposed inverse mapping r -> r_inv.
  51. proposed: List[int | None] = [None] * num_relations
  52. for (r, r_rev), cnt in pair_counts.items():
  53. if cnt == 0:
  54. continue
  55. if cnt / rel_totals[r] >= thresh:
  56. # majority of r's edges are reversed by r_rev
  57. proposed[r] = r_rev
  58. # Build family ids (transitive closure is overkill; data are simple).
  59. family = list(range(num_relations))
  60. for r, rinv in enumerate(proposed):
  61. if rinv is not None:
  62. root = min(family[r], family[rinv])
  63. family[r] = family[rinv] = root
  64. return torch.tensor(family, dtype=torch.long)
  65. def _conditional_relation_entropy(
  66. colours: List[int],
  67. triples: torch.Tensor,
  68. family_map: torch.Tensor,
  69. ) -> float:
  70. counts: Dict[Tuple[int, int], Dict[int, int]] = defaultdict(lambda: defaultdict(int))
  71. m = int(triples.size(0))
  72. fam = family_map # alias for speed
  73. for h, r, t in triples.tolist():
  74. # key = (colours[h], colours[t]) # ordered key (direction‑aware)
  75. key = tuple(sorted((colours[h], colours[t]))) # unordered key (direction‑agnostic)
  76. counts[key][int(fam[r])] += 1 # ▼ use family id
  77. h_cond = 0.0
  78. for rel_counts in counts.values():
  79. total_s = sum(rel_counts.values())
  80. P_s = total_s / m
  81. inv_total = 1.0 / total_s
  82. for cnt in rel_counts.values():
  83. p_rs = cnt * inv_total
  84. h_cond -= P_s * p_rs * math.log(p_rs)
  85. return h_cond
  86. class WLCREC(BaseMetric):
  87. """Class to compute the Weisfeiler-Lehman Entropy Complexity (WLEC) for a knowledge graph."""
  88. def __init__(self, dataset: KGDataset) -> None:
  89. super().__init__(dataset)
  90. def compute(
  91. self,
  92. H: int = 3,
  93. cond_h: int = 1,
  94. inv_thresh: float = 0.9,
  95. return_full: bool = False,
  96. progress: bool = True,
  97. ) -> float | Tuple[float, List[float]]:
  98. """Compute WL-entropy scores and the composite difficulty metric.
  99. Parameters
  100. ----------
  101. dataset
  102. Any object exposing ``.triples``, ``.num_entities``, ``.num_relations``.
  103. H
  104. WL refinement depth (*≥0*). Entropy is recorded for **H+1** layers.
  105. return_full
  106. Also return the list of per-layer entropies.
  107. progress
  108. Print textual progress bar.
  109. Returns
  110. -------
  111. avg_entropy, C_ratio, C_NWLEC, H_cond, D_ratio, D_NWLEC
  112. *Six* scalars (floats). If ``return_full=True`` an extra list of
  113. layer entropies is appended.
  114. """
  115. # compute relation family map once
  116. family_map = _relation_families(
  117. self.dataset.triples, self.dataset.num_relations, thresh=inv_thresh
  118. )
  119. # ─── WL iterations ────────────────────────────────────────────────────
  120. n = self.dataset.num_entities
  121. colour: List[int] = [1] * n # colour 1 for everyone
  122. entropies: List[float] = []
  123. colours_per_h: List[List[int]] = []
  124. def _print_prog(i: int) -> None:
  125. if progress:
  126. print(f"\rWL iteration {i}/{H}", end="", flush=True)
  127. for h in range(H + 1):
  128. _print_prog(h)
  129. colours_per_h.append(colour.copy())
  130. # entropy of current colouring
  131. freq = Counter(colour)
  132. entropies.append(_entropy_from_counter(freq, n))
  133. if h == H:
  134. break
  135. # refine colours
  136. bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {}
  137. next_colour: List[int] = [0] * n
  138. for v in range(n):
  139. T: List[Tuple[int, int, int]] = []
  140. for r, t in self.out_adj[v]:
  141. T.append((r, 0, colour[t])) # outgoing
  142. for r, h_ in self.in_adj[v]:
  143. T.append((r, 1, colour[h_])) # incoming
  144. T.sort()
  145. key = (colour[v], tuple(T))
  146. if key not in bucket:
  147. bucket[key] = len(bucket) + 1
  148. next_colour[v] = bucket[key]
  149. colour = next_colour
  150. _print_prog(H)
  151. if progress:
  152. print()
  153. # ─── Normalised diversity ────────────────────────────────────────────
  154. avg_entropy = sum(entropies) / len(entropies)
  155. c_ratio, c_nwlec = _normalise_scores(entropies, n)
  156. # ─── Residual relation entropy ───────────────────────────────────────
  157. h_cond = _conditional_relation_entropy(
  158. colours_per_h[cond_h], self.dataset.triples, family_map
  159. )
  160. # ─── Composite difficulty - Eq. (1) ──────────────────────────────────
  161. d_ratio = c_ratio * h_cond
  162. d_nwlec = c_nwlec * h_cond
  163. result = (avg_entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec)
  164. if return_full:
  165. result += (entropies,)
  166. return result
  167. # --------------------------------------------------------------------
  168. # Weisfeiler–Lehman colours (unchanged)
  169. # --------------------------------------------------------------------
  170. def wl_colours(self, depth: int) -> List[List[int]]:
  171. triples = self.dataset.triples.tolist()
  172. out_adj, in_adj = defaultdict(list), defaultdict(list)
  173. for h, r, t in triples:
  174. out_adj[h].append((r, t))
  175. in_adj[t].append((r, h))
  176. n_ent = self.dataset.num_entities
  177. colours = [[0] * n_ent] # round-0
  178. for h in range(1, depth + 1):
  179. prev, nxt, sig2c = colours[-1], [0] * n_ent, {}
  180. fresh = 0
  181. for v in range(n_ent):
  182. neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [
  183. ("↑", r, prev[u]) for r, u in in_adj.get(v, [])
  184. ]
  185. neigh.sort()
  186. sig = (prev[v], tuple(neigh))
  187. if sig not in sig2c:
  188. sig2c[sig] = fresh
  189. fresh += 1
  190. nxt[v] = sig2c[sig]
  191. colours.append(nxt)
  192. return colours