wlec.py

from __future__ import annotations

import math
from collections import Counter
from typing import TYPE_CHECKING, Dict, List, Tuple

from .base_metric import BaseMetric

if TYPE_CHECKING:
    from data.kg_dataset import KGDataset


def _entropy_from_counter(counter: Counter[int], n: int) -> float:
    """Shannon entropy ``H = -Σ pᵢ log pᵢ`` in *nats* for the given colour counts."""
    ent = 0.0
    for count in counter.values():
        if count:  # avoid log(0)
            p = count / n
            ent -= p * math.log(p)
    return ent
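
# Worked example: n = 4 entities split evenly over two colours,
# ``Counter({1: 2, 2: 2})``, gives p = (0.5, 0.5) and
# H = -2 * 0.5 * log(0.5) = log 2 ≈ 0.693 nats.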


def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]:
    """Return (C_ratio, C_NWLEC) given *per-iteration* entropies."""
    if n <= 1:
        # Degenerate graph.
        return 0.0, 0.0
    H_plus_1 = len(entropies)
    # Eq. (1): simple ratio normalisation.
    c_ratio = sum(ent / math.log(n) for ent in entropies) / H_plus_1
    # Eq. (2): effective-colour NWLEC.
    k_terms = [(math.exp(ent) - 1) / (n - 1) for ent in entropies]
    c_nwlec = sum(k_terms) / H_plus_1
    return c_ratio, c_nwlec
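
# Sanity check on the two normalisations: if every entity carries its own
# colour, each per-iteration entropy equals log(n), so the ratio term is 1 and
# the effective-colour term is (exp(log n) - 1) / (n - 1) = 1; a single shared
# colour yields 0 for both.  Both scores therefore lie in [0, 1].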


class WLEC(BaseMetric):
    """Compute the Weisfeiler-Lehman Entropy Complexity (WLEC) of a knowledge graph."""

    def __init__(self, dataset: KGDataset) -> None:
        super().__init__(dataset)

    def compute(
        self,
        H: int = 3,
        *,
        return_full: bool = False,
        progress: bool = True,
    ) -> Tuple[float, float, float] | Tuple[float, List[float], float, float]:
  39. """Compute the *average* Weisfeiler-Lehman Entropy Complexity (WLEC).
  40. Parameters
  41. ----------
  42. dataset:
  43. Any :class:`KGDataset`-like object exposing ``triples`` (``LongTensor``),
  44. ``num_entities`` and ``num_relations``.
  45. H:
  46. Number of refinement iterations **after** the initial colouring (so the
  47. colours are updated **H** times and the entropy is measured **H+1** times).
  48. return_full:
  49. If *True*, additionally return the list of entropies for each depth.
  50. progress:
  51. Whether to print a mini progress bar (no external dependency).
  52. Returns
  53. -------
  54. average_entropy or (average_entropy, entropies)
  55. Average of the stored entropies, and optionally the list itself.
  56. """
        n = int(self.num_entities)
        if n == 0:
            msg = "Dataset appears to contain zero entities."
            raise ValueError(msg)

        # Colour assignment - we keep it as a *list* of ints for cheap hashing.
        # Start with colour 1 for every entity (colour 0 is never used).
        colour: List[int] = [1] * n
        entropies: List[float] = []

        # Optional poor-man's progress bar.
        def _print_prog(i: int) -> None:
            if progress:
                print(f"\rIteration {i}/{H}", end="", flush=True)
        for h in range(H + 1):
            _print_prog(h)
            # ------------------------------------------------------------------
            # Step 1: measure entropy of the current colouring
            # ------------------------------------------------------------------
            freq = Counter(colour)
            ent = _entropy_from_counter(freq, n)
            entropies.append(ent)
            if h == H:
                break  # We have reached the requested depth - stop here.
            # ------------------------------------------------------------------
            # Step 2: create refined colours
            # ------------------------------------------------------------------
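            # Two entities receive the same refined colour iff they share the
            # current colour *and* the same multiset of (relation, direction,
            # neighbour-colour) features; ``bucket`` maps each such signature to
            # a compact integer colour ID.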
            bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {}
            next_colour: List[int] = [0] * n
            for v in range(n):
                # Build the multiset T containing outgoing and incoming features.
                T: List[Tuple[int, int, int]] = []
                for r, t in self.out_adj[v]:
                    T.append((r, 0, colour[t]))  # 0 = outgoing (↓)
                for r, h_ in self.in_adj[v]:
                    T.append((r, 1, colour[h_]))  # 1 = incoming (↑)
                T.sort()  # canonical ordering
                key = (colour[v], tuple(T))
                if key not in bucket:
                    bucket[key] = len(bucket) + 1  # new colour ID (start at 1)
                next_colour[v] = bucket[key]
            colour = next_colour  # move to the next iteration
        _print_prog(H)
        if progress:
            print()  # newline after the progress bar
        average_entropy = sum(entropies) / len(entropies)
        c_ratio, c_nwlec = _normalise_scores(entropies, n)
        if return_full:
            return average_entropy, entropies, c_ratio, c_nwlec
        return average_entropy, c_ratio, c_nwlec
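

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; assumes a concrete ``KGDataset`` implementation
# and that ``BaseMetric`` builds ``self.out_adj`` / ``self.in_adj`` and exposes
# ``self.num_entities`` from the dataset's ``triples``):
#
#     from data.kg_dataset import KGDataset
#
#     dataset = KGDataset("path/to/triples")   # hypothetical constructor
#     metric = WLEC(dataset)
#     avg_H, c_ratio, c_nwlec = metric.compute(H=3, progress=False)
#     print(f"avg entropy={avg_H:.3f}  C_ratio={c_ratio:.3f}  C_NWLEC={c_nwlec:.3f}")
# ---------------------------------------------------------------------------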