
c_swklf.py

from __future__ import annotations

import math
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Iterable

import torch
from torch import Tensor

from .base_metric import BaseMetric

if TYPE_CHECKING:
    from data.kg_dataset import KGDataset
    from models.base_model import ERModel


@torch.no_grad()
def _log_prob_distribution(
    fn: Callable[[Tensor], Tensor],
    query: Tensor,
    batch_size: int,
) -> Iterable[Tensor]:
    """
    Yield *log*-probability batches for an arbitrary distribution helper.

    Parameters
    ----------
    fn :
        A bound method of *model* returning **log**-probabilities, e.g.
        ``model.tail_distribution(..., log=True)``.
    query :
        LongTensor of shape (N, 2) containing the query IDs that `fn` expects.
    batch_size :
        Mini-batch size. Choose according to available GPU/CPU memory.
    """
    for i in range(0, len(query), batch_size):
        yield fn(query[i : i + batch_size])  # (B, |Candidates|)
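
# Shape sketch (illustrative only; `fn` here is a stand-in for a bound model
# method such as ``model.tail_distribution``):
#
#     fn = lambda q: torch.zeros(len(q), 100)            # fake log-probs
#     q = torch.randint(0, 100, (10, 2))
#     shapes = [lp.shape for lp in _log_prob_distribution(fn, q, batch_size=4)]
#     # shapes == [torch.Size([4, 100]), torch.Size([4, 100]), torch.Size([2, 100])]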


def _kl_uniform(log_p: Tensor, true_index_sets: list[list[int]]) -> Tensor:
    """
    KL(q || p) where q is uniform over *true_index_sets*.

    Parameters
    ----------
    log_p :
        Tensor of shape (B, N) holding log-probabilities from the model.
    true_index_sets :
        List (length B). Element *b* contains the list of **indices**
        that are correct for row *b*.

    Returns
    -------
    kl : Tensor, shape (B,)
        KL divergence per row (natural log).
    """
    rows: list[Tensor] = []
    for lp_row, idx in zip(log_p, true_index_sets):
        k = len(idx)
        # KL(q || p) = -H(q) - E_q[log p] = -log k - mean(log p over idx);
        # it is 0 exactly when p places mass 1/k on each true index.
        rows.append(-math.log(k) - lp_row[idx].mean())
    # torch.stack keeps the values on log_p's device; torch.tensor() would
    # refuse a list of CUDA scalars.
    return torch.stack(rows)
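
# Sanity check (illustrative): a model that spreads probability 1/2 over
# exactly the two true indices has zero divergence.
#
#     log_p = torch.tensor([[0.5, 0.5, 1e-12]]).log()
#     _kl_uniform(log_p, [[0, 1]])   # -> tensor([0.]) up to float rounding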


@dataclass(slots=True)
class _Accum:
    """Accumulator for a weighted mean: tracks ∑ value·weight and ∑ weight."""

    tot_v: float = 0.0
    tot_w: float = 0.0

    def update(self, value: float, weight: float) -> None:
        self.tot_v += value * weight
        self.tot_w += weight

    @property
    def mean(self) -> float:
        return self.tot_v / self.tot_w if self.tot_w else float("nan")
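
# Usage sketch: the weighted mean of 2.0 (weight 1) and 4.0 (weight 3)
# is (2·1 + 4·3) / (1 + 3) = 3.5.
#
#     acc = _Accum()
#     acc.update(2.0, 1.0)
#     acc.update(4.0, 3.0)
#     acc.mean   # -> 3.5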


class CSWKLF(BaseMetric):
    """
    Comprehensive Structural-Weighted KL-Fitness (C-SWKLF) for evaluating
    knowledge graph embeddings.

    The metric compares the model's tail, head, and relation conditionals
    against the empirical distributions of the dataset via KL divergence,
    converts each divergence into a fitness score exp(-KL) in (0, 1], and
    aggregates the three directions weighted by structural entropy.
    """

    def __init__(self, dataset: KGDataset, model: ERModel) -> None:
        super().__init__(dataset)
        self.model = model

    @torch.no_grad()
    def compute(
        self,
        alpha: float = 1 / 3,
        beta: float = 1 / 3,
        gamma: float = 1 / 3,
        batch_size: int = 1024,
        slice_size: int | None = 2048,
        device: torch.device | str | None = None,
    ) -> float:
        """
        Compute the *Comprehensive Structural-Weighted KL-Fitness*.

        Relies on ``self.model`` (a trained ERModel **with** the three
        distribution helpers) and ``self.dataset`` (providing the mapped
        triples), both passed at construction time.

        Parameters
        ----------
        alpha, beta, gamma :
            Non-negative weights for the three query types, default 1/3 each.
            Must sum to *1*.
        batch_size :
            Number of queries scored per forward pass. Tune w.r.t. GPU/CPU RAM.
        slice_size :
            Forwarded to `tail_distribution` / `head_distribution` /
            `relation_distribution`. `None` disables slicing.
        device :
            Where to do the computation. `None` means the first parameter's
            device.

        Returns
        -------
        score : float in (0, 1]
            Higher means the model is closer to the empirical distributions
            across *all* directions.
        """
        # --------------------------------------------------------------------- #
        #                              Preparation                               #
        # --------------------------------------------------------------------- #
        assert abs(alpha + beta + gamma - 1.0) < 1e-9, "α+β+γ must be 1."
        model_device = next(self.model.parameters()).device
        if device is None:
            device = model_device
        triples = self.dataset.triples.to(device)  # (|T|, 3)
        heads, rels, tails = triples.t()  # each (|T|,)

        # Build index structures ------------------------------------------------
        tails_by_hr: dict[tuple[int, int], list[int]] = defaultdict(list)
        heads_by_rt: dict[tuple[int, int], list[int]] = defaultdict(list)
        rels_by_ht: dict[tuple[int, int], list[int]] = defaultdict(list)
        for h, r, t in zip(heads.tolist(), rels.tolist(), tails.tolist()):
            tails_by_hr[(h, r)].append(t)
            heads_by_rt[(r, t)].append(h)
            rels_by_ht[(h, t)].append(r)

        # Structural entropies & query lists -------------------------------------
        H_tail: dict[int, _Accum] = defaultdict(_Accum)  # per relation r
        H_head: dict[int, _Accum] = defaultdict(_Accum)
        tail_queries: list[tuple[int, int]] = []  # (h, r)
        head_queries: list[tuple[int, int]] = []  # (r, t)
        rel_queries: list[tuple[int, int]] = []  # (h, t)

        # --- tails ---------------------------------------------------------------
        for (h, r), ts in tails_by_hr.items():
            k = len(ts)
            H_tail[r].update(math.log(k), 1.0)
            tail_queries.append((h, r))
        # --- heads ---------------------------------------------------------------
        for (r, t), hs in heads_by_rt.items():
            k = len(hs)
            H_head[r].update(math.log(k), 1.0)
            head_queries.append((r, t))
        # --- relations -------------------------------------------------------------
        H_rel_accum = _Accum()
        for (h, t), rs in rels_by_ht.items():
            k = len(rs)
            H_rel_accum.update(math.log(k), 1.0)
            rel_queries.append((h, t))

        H_struct_tail = {r: acc.mean for r, acc in H_tail.items()}
        H_struct_head = {r: acc.mean for r, acc in H_head.items()}
        # H_struct_rel = H_rel_accum.mean  # scalar; cancels in the aggregation below

        # --------------------------------------------------------------------- #
        #                             KL divergences                             #
        # --------------------------------------------------------------------- #
        def _avg_kl_per_relation(
            queries: list[tuple[int, int]],
            true_idx_map: dict[tuple[int, int], list[int]],
            distribution_fn: Callable[[Tensor], Tensor],
            relation_first: bool = False,  # True if the relation is the first element of each query
        ) -> dict[int, float]:
            """
            Compute the *average* KL(q || p) **per relation**.
            Works for both the tail and the head conditionals.
            """
            kl_sum: dict[int, float] = defaultdict(float)
            count: dict[int, int] = defaultdict(int)
            query_tensor = torch.tensor(queries, dtype=torch.long, device=device)
            offset = 0  # position of the current batch within `queries`
            for log_p in _log_prob_distribution(distribution_fn, query_tensor, batch_size):
                # slice_size is handled inside distribution_fn; log_p: (B, |Candidates|)
                b = log_p.shape[0]
                slice_queries = queries[offset : offset + b]
                offset += b
                true_sets = [true_idx_map[q] for q in slice_queries]
                kl_row = _kl_uniform(log_p, true_sets)  # (B,)
                for (x, y), kl_val in zip(slice_queries, kl_row.tolist()):
                    rel_id = x if relation_first else y
                    kl_sum[rel_id] += kl_val
                    count[rel_id] += 1
            return {r: kl_sum[r] / count[r] for r in kl_sum}
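
        # Each D_* below is an average KL divergence for one query direction;
        # the matching fitness score S_* = exp(-D_*) lies in (0, 1], with 1
        # meaning the model reproduces the empirical conditional exactly.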
        # --- tails ---------------------------------------------------------------
        D_tail = _avg_kl_per_relation(
            tail_queries,
            tails_by_hr,
            lambda q, s=slice_size: self.model.tail_distribution(q, log=True, slice_size=s),
        )
        S_tail = {r: math.exp(-d) for r, d in D_tail.items()}
        # --- heads ---------------------------------------------------------------
        D_head = _avg_kl_per_relation(
            head_queries,
            heads_by_rt,
            lambda q, s=slice_size: self.model.head_distribution(q, log=True, slice_size=s),
            relation_first=True,
        )
        S_head = {r: math.exp(-d) for r, d in D_head.items()}
        # --- relations (single global number) --------------------------------------
        D_rel_sum = 0.0
        rel_query_tensor = torch.tensor(rel_queries, dtype=torch.long, device=device)
        offset = 0
        for log_p in _log_prob_distribution(
            lambda q, s=slice_size: self.model.relation_distribution(q, log=True, slice_size=s),
            rel_query_tensor,
            batch_size,
        ):
            b = log_p.shape[0]
            slice_queries = rel_queries[offset : offset + b]
            offset += b
            true_sets = [rels_by_ht[q] for q in slice_queries]
            D_rel_sum += _kl_uniform(log_p, true_sets).sum().item()
        D_rel = D_rel_sum / H_rel_accum.tot_w  # tot_w == number of (h, t) queries
        S_rel = math.exp(-D_rel)

        # --------------------------------------------------------------------- #
        #                    Weighted aggregation → C-SWKLF                      #
        # --------------------------------------------------------------------- #
        def _weighted_mean(score: dict[int, float], weight: dict[int, float]) -> float:
            num = sum(weight[r] * score[r] for r in score)
            den = sum(weight[r] for r in score)
            return num / den if den else float("nan")

        sw_tail = _weighted_mean(S_tail, H_struct_tail)
        sw_head = _weighted_mean(S_head, H_struct_head)
        # Relations: the structural-entropy weight is a single scalar, so the
        # numerator and denominator cancel and S_rel is used directly.
        sw_rel = S_rel
        return alpha * sw_tail + beta * sw_head + gamma * sw_rel
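
# Example usage (a sketch; `dataset` and `model` stand for a KGDataset and a
# trained ERModel from this code base, assumed to expose `tail_distribution`,
# `head_distribution`, and `relation_distribution` accepting `log=True` and
# `slice_size`):
#
#     metric = CSWKLF(dataset, model)
#     score = metric.compute(alpha=1 / 3, beta=1 / 3, gamma=1 / 3,
#                            batch_size=512, slice_size=2048)
#     print(f"C-SWKLF: {score:.4f}")   # closer to 1 is better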