
greedy_crec.py

from __future__ import annotations

import random
from math import log
from typing import Dict, List, Tuple

import torch
from torch import Tensor

from tools import get_pretty_logger

logger = get_pretty_logger(__name__)

Entity = int
Relation = int
Triple = Tuple[Entity, Relation, Entity]
Color = int
# ---------------------------------------------------------------------
# Weisfeiler–Lehman colouring (GPU friendly)
# ---------------------------------------------------------------------
def wl_colours_gpu(triples: Tensor, n_ent: int, depth: int, device: torch.device) -> List[Tensor]:
    """
    GPU implementation of classic Weisfeiler-Lehman refinement.

    1. Each node starts with colour 0.
    2. At iteration h we hash the multiset of ⬇︎ (r, colour(t)) and
       ⬆︎ (r, colour(h)) signatures into new colour ids.
    3. Hashing is done with a fast 64-bit mix; collisions are unlikely and
       immaterial for CREC (only *relative* colours matter).
    """
    h, r, t = triples.unbind(dim=1)  # each (|T|,)
    colours = [torch.zeros(n_ent, dtype=torch.long, device=device)]  # C⁰ = 0
    # pre-compute 64-bit mix constants once
    mix_r = (
        torch.arange(r.max() + 1, device=device, dtype=torch.long)
        * 1_146_189_683_093_321_123
    )
    mix_c = 8_636_673_225_737_527_201
    # golden-ratio constant 0x9E3779B97F4A7C15, written as a signed int64 so
    # PyTorch accepts it as a long scalar (the unsigned literal overflows)
    salt_in = -7_046_029_254_386_353_131
    for _ in range(depth):
        prev = colours[-1]  # (|V|,)
        # signatures for outgoing edges
        sig_out = mix_r[r] ^ (prev[t] * mix_c)
        # signatures for incoming edges
        sig_in = mix_r[r] ^ (prev[h] * mix_c) ^ salt_in
        # bucket signatures back to the source/target nodes
        col_out = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, h, sig_out)
        col_in = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, t, sig_in)
        # combine ↓ and ↑ multiset hashes with the current colour
        raw = (prev * 3_205_813_371) ^ col_out ^ (col_in << 1)
        # re-map to dense consecutive ids with torch.unique
        _, new = torch.unique(raw, sorted=True, return_inverse=True)
        colours.append(new)
    return colours  # length = depth + 1, each (|V|,) long
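
# Worked example (illustrative): with triples {(0, r0, 1), (1, r0, 0)} both
# nodes see identical in/out signature multisets at every depth, so they keep a
# single shared colour forever; adding a third triple (0, r1, 2) gives all
# three nodes distinct signatures after one refinement step, hence three colours.
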
# ---------------------------------------------------------------------
# WL-CREC metric
# ---------------------------------------------------------------------
def wl_crec_gpu(
    triples: Tensor, colours: List[Tensor], n_ent: int, n_rel: int, depth: int, device: torch.device
) -> float:
    log_n = log(max(n_ent, 2))
    # ------- pattern-diversity C -------------------------------------
    C = 0.0
    for c in colours:  # (|V|,)
        # histogram via bincount on GPU
        hist = torch.bincount(c).float()
        p = hist / n_ent
        C += -(p * torch.log(p.clamp_min(1e-30))).sum().item()
    C /= (depth + 1) * log_n
    # ------- residual entropy H_c ------------------------------------
    h, r, t = triples.unbind(dim=1)
    sigma, tau = colours[depth][h], colours[depth][t]  # (|T|,)
    # encode (sigma, tau) pairs into a single 64-bit key
    sig_keys = (sigma.to(torch.int64) << 32) + tau.to(torch.int64)
    # total count per signature
    sig_unique, sig_inv, sig_counts = torch.unique(
        sig_keys, return_inverse=True, return_counts=True, sorted=False
    )
    # build the 2-D contingency table counts[(sigma, tau), r]: one bincount over
    # flattened (signature, relation) keys yields the dense table directly
    m = triples.size(0)
    keys_2d = sig_inv * n_rel + r
    rc_dense = torch.bincount(keys_2d, minlength=len(sig_unique) * n_rel).view(-1, n_rel)
    # conditional entropy per (sigma, tau), weighted by signature frequency
    p_r = rc_dense.float() / sig_counts.unsqueeze(1).float()
    inner = -(p_r * torch.log(p_r.clamp_min(1e-30))).sum(dim=1)  # (|signatures|,)
    Hc = (sig_counts.float() / m * inner).sum().item() / log(max(n_rel, 2))
    return C * Hc
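
# Sanity notes: if every (sigma, tau) signature admits exactly one relation,
# every inner entropy term is 0, so H_c = 0 and WL-CREC = 0 regardless of C;
# conversely, uniform relation usage within each signature drives H_c towards 1.
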
# ---------------------------------------------------------------------
# delta-entropy helpers (still CPU side for clarity; cost is negligible)
# ---------------------------------------------------------------------
# The deterministic delta-formulas depend on small dictionaries and are
# evaluated on ≤256 candidate edges each iteration; copy them from the
# original script (illustrative stand-ins follow below).
inner_entropy = ...        # unchanged
delta_h_cond_remove = ...  # unchanged
delta_h_cond_add = ...     # unchanged
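
# The stand-ins below are a minimal reconstruction so the file runs stand-alone;
# they are NOT the original formulas, so prefer those where available. They
# assume H_c is the normalised conditional entropy from wl_crec_gpu, written as
# H_c = F / (m * log|R|) with F = sum_s f_s and
# f_s = c_s*log(c_s) - sum_r c_{s,r}*log(c_{s,r}): editing one edge changes a
# single f_s term plus the global divisor m, which gives the closed-form deltas
# used here. The `crec` argument is treated as a proxy for the current H_c
# (an assumption; the exact role of `crec` in the original is not shown).


def _f_sig(sig: Tuple[int, int], sig_cnt: Dict[Tuple[int, int], int],
           rel_cnt: Dict[Tuple[int, int, int], int], n_rel: int) -> float:
    """f_s = c_s*log(c_s) - sum_r c_{s,r}*log(c_{s,r}) for one signature."""
    c = sig_cnt.get(sig, 0)
    if c == 0:
        return 0.0
    val = c * log(c)
    for rr in range(n_rel):
        k = rel_cnt.get((sig[0], sig[1], rr), 0)
        if k:
            val -= k * log(k)
    return val


def inner_entropy(sig, rel_cnt, c_sig, n_rel):
    """-sum_r p(r|sig)*log p(r|sig); the signature here is a guess, as this
    helper is never called in the excerpt above."""
    if c_sig == 0:
        return 0.0
    ent = 0.0
    for rr in range(n_rel):
        k = rel_cnt.get((sig[0], sig[1], rr), 0)
        if k:
            p = k / c_sig
            ent -= p * log(p)
    return ent


def delta_h_cond_remove(sig, r, sig_cnt, rel_cnt, total, crec, n_rel):
    """Estimated change in H_c if one (sig, r) edge is removed (sketch)."""
    if total <= 1:
        return 0.0
    key = (sig[0], sig[1], r)
    f_old = _f_sig(sig, sig_cnt, rel_cnt, n_rel)
    sig_cnt[sig] -= 1
    rel_cnt[key] -= 1
    f_new = _f_sig(sig, sig_cnt, rel_cnt, n_rel)
    sig_cnt[sig] += 1  # restore the shared count dictionaries
    rel_cnt[key] += 1
    return (f_new - f_old) / log(max(n_rel, 2)) / (total - 1) + crec / (total - 1)


def delta_h_cond_add(sig, r, sig_cnt, rel_cnt, total, crec, n_rel):
    """Estimated change in H_c if one (sig, r) edge is duplicated (sketch)."""
    key = (sig[0], sig[1], r)
    f_old = _f_sig(sig, sig_cnt, rel_cnt, n_rel)
    sig_cnt[sig] = sig_cnt.get(sig, 0) + 1
    rel_cnt[key] = rel_cnt.get(key, 0) + 1
    f_new = _f_sig(sig, sig_cnt, rel_cnt, n_rel)
    sig_cnt[sig] -= 1  # restore
    rel_cnt[key] -= 1
    return (f_new - f_old) / log(max(n_rel, 2)) / (total + 1) - crec / (total + 1)
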
# ---------------------------------------------------------------------
# Greedy delta-search driver
# ---------------------------------------------------------------------
def greedy_delta_tune_gpu(
    triples: List[Triple],
    n_ent: int,
    n_rel: int,
    depth: int,
    lower: float,
    upper: float,
    device: torch.device,
    max_iters: int = 40_000,
    sample_size: int = 256,
    seed: int = 0,
) -> List[Triple]:
    # book-keeping in Python lists (cheap); heavy math in torch on GPU
    rng = random.Random(seed)

    # cache torch version of triples for fast metric eval
    def triples_to_tensor(ts: List[Triple]) -> Tensor:
        return torch.tensor(ts, dtype=torch.long, device=device)
    for it in range(max_iters):
        tri_tensor = triples_to_tensor(triples)
        colours = wl_colours_gpu(tri_tensor, n_ent, depth, device)
        crec = wl_crec_gpu(tri_tensor, colours, n_ent, n_rel, depth, device)
        if lower <= crec <= upper:
            logger.info("WL-CREC %.4f reached after %d edits (|T|=%d)", crec, it, len(triples))
            return triples
        # ----------------------------------------------------------------
        # build signature statistics (CPU, cheap: O(|T|))
        # ----------------------------------------------------------------
        depth_col = colours[depth].cpu()
        sig_cnt: Dict[Tuple[int, int], int] = {}
        rel_cnt: Dict[Tuple[int, int, int], int] = {}
        for h, r, t in triples:
            sigma, tau = int(depth_col[h]), int(depth_col[t])
            sig_cnt[(sigma, tau)] = sig_cnt.get((sigma, tau), 0) + 1
            rel_cnt[(sigma, tau, r)] = rel_cnt.get((sigma, tau, r), 0) + 1
        det_edges, div_edges = [], []
        for idx, (h, r, t) in enumerate(triples):
            sigma, tau = int(depth_col[h]), int(depth_col[t])
            # "deterministic" edge: its signature admits exactly one relation
            if sum(rel_cnt.get((sigma, tau, rr), 0) > 0 for rr in range(n_rel)) == 1:
                det_edges.append(idx)
            if sigma != tau:
                div_edges.append(idx)
        # ----------------------------------------------------------------
        # candidate generation + best edit selection
        # ----------------------------------------------------------------
        total = len(triples)
        target_high = crec < lower  # need to raise WL-CREC?
        candidates = []
        if target_high and det_edges:
            rng.shuffle(det_edges)
            for idx in det_edges[:sample_size]:
                h, r, t = triples[idx]
                sig = (int(depth_col[h]), int(depth_col[t]))
                delta = delta_h_cond_remove(sig, r, sig_cnt, rel_cnt, total, crec, n_rel)
                if delta > 0:
                    candidates.append(("remove", idx, delta))
        elif not target_high and det_edges:
            rng.shuffle(det_edges)
            for idx in det_edges[:sample_size]:
                h, r, t = triples[idx]
                sig = (int(depth_col[h]), int(depth_col[t]))
                delta = delta_h_cond_add(sig, r, sig_cnt, rel_cnt, total, crec, n_rel)
                if delta < 0:
                    candidates.append(("add", (h, r, t), delta))
        # fall-back heuristics
        if not candidates:
            if target_high and div_edges:
                idx = rng.choice(div_edges)
                candidates.append(("remove", idx, 1e-9))
            elif not target_high:
                idx = rng.choice(det_edges) if det_edges else rng.randrange(total)
                h, r, t = triples[idx]
                candidates.append(("add", (h, r, t), -1e-9))
        if not candidates:
            # raising WL-CREC with no diverse edges left: nothing to edit this
            # round, so resample on the next iteration instead of crashing
            continue
        best = (
            max(candidates, key=lambda x: x[2])
            if target_high
            else min(candidates, key=lambda x: x[2])
        )
        # apply edit
        if best[0] == "remove":
            triples.pop(best[1])
        else:
            triples.append(best[1])
        if (it + 1) % 1_000 == 0:
            logger.info("[iter %d] WL-CREC %.4f |T|=%d", it + 1, crec, len(triples))
    raise RuntimeError("Max iterations exceeded without hitting WL-CREC band.")
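
# ---------------------------------------------------------------------
# Smoke test: a quick stand-alone check of the two GPU metrics on a random
# toy graph. The sizes below are arbitrary illustration values; the expected
# output is just a finite WL-CREC score, not any particular number.
# ---------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rng = random.Random(0)
    n_ent, n_rel, depth = 50, 4, 2
    toy = [(rng.randrange(n_ent), rng.randrange(n_rel), rng.randrange(n_ent)) for _ in range(300)]
    tri = torch.tensor(toy, dtype=torch.long, device=device)
    cols = wl_colours_gpu(tri, n_ent, depth, device)
    score = wl_crec_gpu(tri, cols, n_ent, n_rel, depth, device)
    logger.info("toy graph WL-CREC = %.4f", score)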