123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226 |
- from __future__ import annotations
-
- import logging
- import random
- from math import log
- from typing import Dict, List, Tuple
-
- import torch
- from torch import Tensor
-
- from tools import get_pretty_logger
-
# Module-level logger built by the project helper `get_pretty_logger`.
# NOTE(review): this rebinds the name `logging`, so the stdlib `logging` module
# imported at the top of the file is no longer reachable under that name.
# Code in this file calls `logging.info(...)` on this object; consider renaming
# it to `logger` together with those call sites.
logging = get_pretty_logger(__name__)


# Lightweight aliases used in the signatures of this module.
Entity = int  # integer id of a node
Relation = int  # integer id of an edge label
Triple = Tuple[Entity, Relation, Entity]  # unpacked as (h, r, t) elsewhere in this file
Color = int  # presumably a Weisfeiler-Lehman colour id; not referenced in the visible code
-
-
- # ---------------------------------------------------------------------
- # Weisfeiler–Lehman colouring (GPU friendly)
- # ---------------------------------------------------------------------
def wl_colours_gpu(triples: Tensor, n_ent: int, depth: int, device: torch.device) -> List[Tensor]:
    """
    Weisfeiler-Lehman colour refinement implemented with tensor operations.

    1. Every node starts with colour 0.
    2. In each round, every edge (h, r, t) contributes an "outgoing" signature
       built from (r, colour(t)) to node h and an "incoming" signature built
       from (r, colour(h)) to node t. Signatures are summed per node, which
       makes the result independent of edge order.
    3. The per-node sums are combined with the node's previous colour and
       re-mapped to dense consecutive ids with ``torch.unique``.

    Hash collisions are possible in principle but unlikely; only the induced
    partition of the nodes matters to callers, not the numeric ids.

    Args:
        triples: integer tensor of shape (|T|, 3) holding (h, r, t) rows.
        n_ent: number of nodes; colour tensors have this length.
        depth: number of refinement rounds.
        device: device on which the colour tensors are allocated.

    Returns:
        A list of ``depth + 1`` long tensors of shape (n_ent,); entry ``k`` is
        the colouring after ``k`` rounds.
    """
    h, r, t = triples.unbind(dim=1)  # each (|T|,)
    colours = [torch.zeros(n_ent, dtype=torch.long, device=device)]  # C⁰ = 0

    # An empty edge list has no maximum relation id; treat it as zero relations.
    n_rel = int(r.max()) + 1 if r.numel() > 0 else 0

    # Per-relation mixing constants. Ids are shifted by one so that relation 0
    # does not map to 0 and thereby vanish from every signature.
    mix_r = (
        torch.arange(1, n_rel + 1, device=device, dtype=torch.long)
        * 1_146_189_683_093_321_123
    )
    mix_c = 8_636_673_225_737_527_201
    # 0x9E3779B97F4A7C15 exceeds the signed 64-bit range, so it is expressed as
    # the signed integer with the same bit pattern.
    in_tag = 0x9E3779B97F4A7C15 - (1 << 64)

    rel_mix = mix_r[r]  # loop-invariant: depends only on the edge labels

    for _ in range(depth):
        prev = colours[-1]  # (|V|,)

        # Colours are shifted by one for the same reason as relation ids:
        # colour 0 must still contribute a non-zero term.
        sig_out = rel_mix ^ ((prev[t] + 1) * mix_c)
        sig_in = rel_mix ^ ((prev[h] + 1) * mix_c) ^ in_tag

        # Sum signatures per node (order-independent multiset hash).
        col_out = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, h, sig_out)
        col_in = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, t, sig_in)

        # Combine both directions with the node's current colour.
        raw = (prev * 3_205_813_371) ^ col_out ^ (col_in << 1)

        # Re-map to dense consecutive ids.
        _, new = torch.unique(raw, sorted=True, return_inverse=True)
        colours.append(new)

    return colours  # length = depth + 1, each (|V|,) long
-
-
- # ---------------------------------------------------------------------
- # WL‑CREC metric
- # ---------------------------------------------------------------------
def wl_crec_gpu(
    triples: Tensor, colours: List[Tensor], n_ent: int, n_rel: int, depth: int, device: torch.device
) -> float:
    """
    Compute the WL-CREC score ``C * Hc`` for a set of triples.

    * ``C`` (pattern diversity): the Shannon entropy of the colour histogram,
      summed over every tensor in ``colours`` and divided by
      ``(depth + 1) * log(max(n_ent, 2))``.
    * ``Hc`` (residual entropy): the entropy of the relation given the pair
      (colour(h), colour(t)) taken from ``colours[depth]``, weighted by how
      often each pair occurs and divided by ``log(max(n_rel, 2))``.

    Args:
        triples: integer tensor of shape (|T|, 3) holding (h, r, t) rows.
            Relation ids are assumed to lie in ``[0, n_rel)``; this is not
            checked here.
        colours: per-round node colourings, each a 1-D integer tensor.
        n_ent: number of nodes, used to normalise the colour histograms.
        n_rel: number of relation ids.
        depth: index into ``colours`` of the colouring used for ``Hc``.
        device: kept for interface compatibility; all intermediate tensors
            live on the device of the inputs.

    Returns:
        The product ``C * Hc`` as a Python float (0.0 when there are no triples).
    """
    log_n = log(max(n_ent, 2))

    # ------- pattern-diversity C -------------------------------------
    C = 0.0
    for c in colours:  # (|V|,)
        hist = torch.bincount(c).float()
        p = hist / n_ent
        # clamp avoids log(0); such terms are multiplied by p == 0 anyway
        C += -(p * torch.log(p.clamp_min(1e-30))).sum().item()
    C /= (depth + 1) * log_n

    # ------- residual entropy H_c ------------------------------------
    m = triples.size(0)
    if m == 0:
        # no edges: the conditional entropy is an empty sum
        return C * 0.0

    h, r, t = triples.unbind(dim=1)
    final = colours[depth]
    sigma, tau = final[h], final[t]  # (|T|,)

    # encode each (sigma, tau) pair into a single 64-bit key
    sig_keys = (sigma.to(torch.int64) << 32) + tau.to(torch.int64)

    # compact id and total edge count per distinct (sigma, tau) pair
    sig_unique, sig_inv, sig_counts = torch.unique(
        sig_keys, return_inverse=True, return_counts=True, sorted=False
    )
    n_sig = sig_unique.numel()

    # contingency table counts[(sigma, tau), r]: count flat cell indices,
    # then fold the flat histogram back into (n_sig, n_rel)
    flat_cells = sig_inv * n_rel + r
    rc_dense = torch.bincount(flat_cells, minlength=n_sig * n_rel).view(n_sig, n_rel)

    # conditional entropy of the relation within each (sigma, tau) pair
    p_r = rc_dense.float() / sig_counts.unsqueeze(1).float()
    inner = -(p_r * torch.log(p_r.clamp_min(1e-30))).sum(dim=1)  # (n_sig,)

    Hc = (sig_counts.float() / m * inner).sum().item() / log(max(n_rel, 2))

    return C * Hc
-
-
- # ---------------------------------------------------------------------
- # delta‑entropy helpers (still CPU side for clarity; cost is negligible)
- # ---------------------------------------------------------------------
- # The deterministic delta‑formulas depend on small dictionaries and are evaluated
- # on ≤256 candidate edges each iteration – copy them from the original script.
def _missing_delta_helper(name: str):
    """Return a callable stand-in that fails loudly until helper *name* is ported."""

    def _stub(*args, **kwargs):
        raise NotImplementedError(
            f"{name} is a placeholder; port its implementation from the original script."
        )

    _stub.__name__ = name
    return _stub


# The real formulas are not part of this file. Binding the names to callables
# (instead of the bare `...` literal) turns an opaque "'ellipsis' object is not
# callable" failure into an error that names the missing helper.
inner_entropy = _missing_delta_helper("inner_entropy")
delta_h_cond_remove = _missing_delta_helper("delta_h_cond_remove")
delta_h_cond_add = _missing_delta_helper("delta_h_cond_add")
-
-
- # ---------------------------------------------------------------------
- # Greedy delta‑search driver
- # ---------------------------------------------------------------------
def greedy_delta_tune_gpu(
    triples: List[Triple],
    n_ent: int,
    n_rel: int,
    depth: int,
    lower: float,
    upper: float,
    device: torch.device,
    max_iters: int = 40_000,
    sample_size: int = 256,
    seed: int = 0,
) -> List[Triple]:
    """
    Greedily add or remove single triples until WL-CREC lies in [lower, upper].

    Each iteration recomputes the colouring and the score. If the score is
    outside the band, up to ``sample_size`` "deterministic" edges (edges whose
    (colour(h), colour(t)) signature occurs with exactly one relation) are
    scored with the delta helpers and the best edit is applied:

    * score below ``lower``: remove the sampled edge with the largest positive
      delta; if none qualifies, remove a random edge whose endpoint colours
      differ.
    * score above ``upper``: append a copy of the sampled edge with the most
      negative delta; if none qualifies, append a copy of a random edge.

    Args:
        triples: list of (h, r, t) triples. It is edited IN PLACE and the same
            list object is returned.
        n_ent: number of nodes.
        n_rel: number of relation ids.
        depth: number of WL refinement rounds.
        lower: inclusive lower bound of the target band.
        upper: inclusive upper bound of the target band.
        device: device used for the tensor computations.
        max_iters: maximum number of edits to attempt.
        sample_size: maximum number of deterministic edges scored per iteration.
        seed: seed of the private random generator.

    Returns:
        ``triples`` once its score lies inside the band.

    Raises:
        RuntimeError: if no edit is possible in some iteration, or if the band
            is not reached within ``max_iters`` iterations.
    """
    rng = random.Random(seed)

    def triples_to_tensor(ts: List[Triple]) -> Tensor:
        # reshape keeps the (|T|, 3) layout even when the list is empty
        return torch.tensor(ts, dtype=torch.long, device=device, requires_grad=False).reshape(-1, 3)

    for it in range(max_iters):
        tri_tensor = triples_to_tensor(triples)
        colours = wl_colours_gpu(tri_tensor, n_ent, depth, device)
        crec = wl_crec_gpu(tri_tensor, colours, n_ent, n_rel, depth, device)

        if lower <= crec <= upper:
            logging.info("WL‑CREC %.4f reached after %d edits (|T|=%d)", crec, it, len(triples))
            return triples

        # ----------------------------------------------------------------
        # build signature statistics (CPU, O(|T|))
        # ----------------------------------------------------------------
        # one bulk transfer instead of a tensor lookup per edge endpoint
        depth_col = colours[depth].tolist()
        edge_sigs = [(depth_col[h], depth_col[t]) for h, _, t in triples]

        sig_cnt: Dict[Tuple[int, int], int] = {}
        rel_cnt: Dict[Tuple[int, int, int], int] = {}
        for sig, (_, r, _) in zip(edge_sigs, triples):
            sig_cnt[sig] = sig_cnt.get(sig, 0) + 1
            rel_key = (sig[0], sig[1], r)
            rel_cnt[rel_key] = rel_cnt.get(rel_key, 0) + 1

        # number of distinct in-range relations seen for each signature
        rel_kinds: Dict[Tuple[int, int], int] = {}
        for sigma, tau, rel in rel_cnt:
            if 0 <= rel < n_rel:
                rel_kinds[(sigma, tau)] = rel_kinds.get((sigma, tau), 0) + 1

        # deterministic edges: their signature admits exactly one relation
        det_edges = [idx for idx, sig in enumerate(edge_sigs) if rel_kinds.get(sig, 0) == 1]
        # diverse edges: endpoint colours differ
        div_edges = [idx for idx, (sigma, tau) in enumerate(edge_sigs) if sigma != tau]

        # ----------------------------------------------------------------
        # candidate generation + best edit selection
        # ----------------------------------------------------------------
        total = len(triples)
        target_high = crec < lower  # need to raise WL-CREC?
        candidates = []

        if det_edges:
            rng.shuffle(det_edges)
            for idx in det_edges[:sample_size]:
                h, r, t = triples[idx]
                sig = edge_sigs[idx]
                if target_high:
                    delta = delta_h_cond_remove(sig, r, sig_cnt, rel_cnt, total, crec, n_rel)
                    if delta > 0:
                        candidates.append(("remove", idx, delta))
                else:
                    delta = delta_h_cond_add(sig, r, sig_cnt, rel_cnt, total, crec, n_rel)
                    if delta < 0:
                        candidates.append(("add", (h, r, t), delta))

        # fall-back heuristics when no sampled edit moves the score the right way
        if not candidates:
            if target_high:
                if div_edges:
                    candidates.append(("remove", rng.choice(div_edges), 1e-9))
            elif total > 0:
                idx = rng.choice(det_edges) if det_edges else rng.randrange(total)
                h, r, t = triples[idx]
                candidates.append(("add", (h, r, t), -1e-9))

        if not candidates:
            # Previously this surfaced as a ValueError from max()/min() on an
            # empty list; fail with the function's own error type instead.
            raise RuntimeError(
                "No candidate edit available at iteration %d (score %.4f, |T|=%d)."
                % (it, crec, total)
            )

        if target_high:
            best = max(candidates, key=lambda cand: cand[2])
        else:
            best = min(candidates, key=lambda cand: cand[2])

        # apply edit
        if best[0] == "remove":
            triples.pop(best[1])
        else:
            triples.append(best[1])

        if (it + 1) % 1_000 == 0:
            logging.info("[iter %d] WL‑CREC %.4f |T|=%d", it + 1, crec, len(triples))

    raise RuntimeError("Max iterations exceeded without hitting WL‑CREC band.")
|