from __future__ import annotations

import random
from math import log
from typing import Dict, List, Tuple

import torch
from torch import Tensor

from tools import get_pretty_logger

logger = get_pretty_logger(__name__)

Entity = int
Relation = int
Triple = Tuple[Entity, Relation, Entity]
Color = int


# ---------------------------------------------------------------------
# Weisfeiler–Lehman colouring (GPU friendly)
# ---------------------------------------------------------------------
def wl_colours_gpu(triples: Tensor, n_ent: int, depth: int, device: torch.device) -> List[Tensor]:
    """
    GPU implementation of classic Weisfeiler-Lehman refinement.

    1. Each node starts with colour 0.
    2. At iteration k we hash the multiset of outgoing (r, colour(t)) and
       incoming (r, colour(h)) signatures into new colour ids.
    3. Hashing is done with a fast 64-bit mix; collisions are unlikely and
       immaterial for CREC (only *relative* colours matter).
    """
    h, r, t = triples.unbind(dim=1)  # each (|T|,)
    colours = [torch.zeros(n_ent, dtype=torch.long, device=device)]  # C⁰ = 0

    # pre-compute 64-bit mix constants once
    mix_r = (
        torch.arange(triples[:, 1].max() + 1, device=device, dtype=torch.long)
        * 1_146_189_683_093_321_123
    )
    mix_c = 8_636_673_225_737_527_201

    for _ in range(depth):
        prev = colours[-1]  # (|V|,)
        # signatures for outgoing edges
        sig_out = mix_r[r] ^ (prev[t] * mix_c)
        # signatures for incoming edges (extra constant breaks the symmetry)
        sig_in = mix_r[r] ^ (prev[h] * mix_c) ^ 0x9E3779B97F4A7C15

        # bucket signatures back to the source/target nodes; the sum is
        # order-invariant, so index_add_ acts as a cheap multiset hash
        col_out = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, h, sig_out)
        col_in = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, t, sig_in)

        # combine outgoing and incoming multiset hashes with the current colour
        raw = (prev * 3_205_813_371) ^ col_out ^ (col_in << 1)

        # re-map to dense consecutive ids with torch.unique
        _, new = torch.unique(raw, sorted=True, return_inverse=True)
        colours.append(new)

    return colours  # length = depth + 1, each (|V|,) long
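
# ---------------------------------------------------------------------
# Usage sketch for wl_colours_gpu (illustrative only: the tiny 4-node
# graph and the `_demo_wl_colours` helper below are not part of the
# original script; they just show the expected input layout and shapes).
# ---------------------------------------------------------------------
def _demo_wl_colours() -> None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 4 entities, 2 relations; each row is a (head, relation, tail) triple
    triples = torch.tensor(
        [[0, 0, 1], [1, 0, 2], [2, 1, 3], [3, 1, 0]],
        dtype=torch.long,
        device=device,
    )
    colours = wl_colours_gpu(triples, n_ent=4, depth=2, device=device)
    assert len(colours) == 3                      # depth + 1 refinement levels
    assert all(c.shape == (4,) for c in colours)  # one colour id per entity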

# ---------------------------------------------------------------------
# WL-CREC metric
# ---------------------------------------------------------------------
def wl_crec_gpu(
    triples: Tensor,
    colours: List[Tensor],
    n_ent: int,
    n_rel: int,
    depth: int,
    device: torch.device,
) -> float:
    log_n = log(max(n_ent, 2))

    # ------- pattern-diversity C -------------------------------------
    C = 0.0
    for c in colours:  # (|V|,)
        # colour histogram via bincount on GPU; empty bins contribute 0
        hist = torch.bincount(c).float()
        p = hist / n_ent
        C += -(p * torch.log(p.clamp_min(1e-30))).sum().item()
    C /= (depth + 1) * log_n

    # ------- residual entropy H_c ------------------------------------
    h, r, t = triples.unbind(dim=1)
    sigma, tau = colours[depth][h], colours[depth][t]  # (|T|,)

    # encode (sigma, tau) pairs into a single 64-bit key
    sig_keys = (sigma.to(torch.int64) << 32) + tau.to(torch.int64)

    # total count per signature
    sig_unique, sig_inv, sig_counts = torch.unique(
        sig_keys, return_inverse=True, return_counts=True
    )

    # build the 2-D contingency table counts[(sigma, tau), r]: accumulate
    # into a flat buffer with index_add_, then reshape to (|signatures|, n_rel)
    m = triples.size(0)
    keys_2d = sig_inv * n_rel + r
    rc_dense = torch.zeros(len(sig_unique) * n_rel, device=device, dtype=torch.long)
    rc_dense.index_add_(0, keys_2d, torch.ones(m, device=device, dtype=torch.long))
    rc_dense = rc_dense.view(len(sig_unique), n_rel)

    # conditional entropy per (sigma, tau)
    p_r = rc_dense.float() / sig_counts.unsqueeze(1).float()
    inner = -(p_r * torch.log(p_r.clamp_min(1e-30))).sum(dim=1)  # (|signatures|,)
    Hc = (sig_counts.float() / m * inner).sum().item() / log(max(n_rel, 2))

    return C * Hc


# ---------------------------------------------------------------------
# delta-entropy helpers (still CPU side for clarity; cost is negligible)
# ---------------------------------------------------------------------
# The deterministic delta-formulas depend on small dictionaries and are
# evaluated on ≤256 candidate edges each iteration – copy them from the
# original script.
inner_entropy = ...  # unchanged
delta_h_cond_remove = ...  # unchanged
delta_h_cond_add = ...  # unchanged
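
# ---------------------------------------------------------------------
# Sketch of the delta helpers (assumption-laden: these are NOT the
# original formulas, which should be copied from the original script as
# noted above). They recompute the one affected signature's contribution
# to H_c = (1/m) * Σ_sig n_sig * H(R | sig) before and after a single
# edge removal, ignoring the small m -> m-1 renormalisation term.
# ---------------------------------------------------------------------
def _inner_entropy_sketch(
    sig: Tuple[int, int],
    sig_cnt: Dict[Tuple[int, int], int],
    rel_cnt: Dict[Tuple[int, int, int], int],
    n_rel: int,
) -> float:
    """H(R | sigma, tau) for one signature, from the count dictionaries."""
    n = sig_cnt.get(sig, 0)
    if n == 0:
        return 0.0
    ent = 0.0
    for rr in range(n_rel):
        c = rel_cnt.get((sig[0], sig[1], rr), 0)
        if c:
            p = c / n
            ent -= p * log(p)
    return ent


def _delta_h_cond_remove_sketch(
    sig: Tuple[int, int],
    r: int,
    sig_cnt: Dict[Tuple[int, int], int],
    rel_cnt: Dict[Tuple[int, int, int], int],
    total: int,
    n_rel: int,
) -> float:
    """Approximate change in H_c when one (sig, r) edge is removed."""
    before = sig_cnt[sig] * _inner_entropy_sketch(sig, sig_cnt, rel_cnt, n_rel)
    # simulate the removal in local copies of the two count dictionaries
    sig_cnt2 = dict(sig_cnt)
    rel_cnt2 = dict(rel_cnt)
    sig_cnt2[sig] -= 1
    rel_cnt2[(sig[0], sig[1], r)] -= 1
    after = sig_cnt2[sig] * _inner_entropy_sketch(sig, sig_cnt2, rel_cnt2, n_rel)
    return (after - before) / total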

# ---------------------------------------------------------------------
# Greedy delta-search driver
# ---------------------------------------------------------------------
def greedy_delta_tune_gpu(
    triples: List[Triple],
    n_ent: int,
    n_rel: int,
    depth: int,
    lower: float,
    upper: float,
    device: torch.device,
    max_iters: int = 40_000,
    sample_size: int = 256,
    seed: int = 0,
) -> List[Triple]:
    # book-keeping in Python lists (cheap); heavy math in torch on GPU
    rng = random.Random(seed)

    # cache torch version of triples for fast metric evaluation
    def triples_to_tensor(ts: List[Triple]) -> Tensor:
        return torch.tensor(ts, dtype=torch.long, device=device)

    for it in range(max_iters):
        tri_tensor = triples_to_tensor(triples)
        colours = wl_colours_gpu(tri_tensor, n_ent, depth, device)
        crec = wl_crec_gpu(tri_tensor, colours, n_ent, n_rel, depth, device)

        if lower <= crec <= upper:
            logger.info("WL-CREC %.4f reached after %d edits (|T|=%d)", crec, it, len(triples))
            return triples

        # ----------------------------------------------------------------
        # build signature statistics (CPU, cheap: O(|T|))
        # ----------------------------------------------------------------
        depth_col = colours[depth].cpu()
        sig_cnt: Dict[Tuple[int, int], int] = {}
        rel_cnt: Dict[Tuple[int, int, int], int] = {}
        for h, r, t in triples:
            sigma, tau = int(depth_col[h]), int(depth_col[t])
            sig_cnt[(sigma, tau)] = sig_cnt.get((sigma, tau), 0) + 1
            rel_cnt[(sigma, tau, r)] = rel_cnt.get((sigma, tau, r), 0) + 1

        # deterministic edges carry exactly one relation per signature;
        # diverse edges connect two different colour classes
        det_edges, div_edges = [], []
        for idx, (h, r, t) in enumerate(triples):
            sigma, tau = int(depth_col[h]), int(depth_col[t])
            if sum(rel_cnt.get((sigma, tau, rr), 0) > 0 for rr in range(n_rel)) == 1:
                det_edges.append(idx)
            if sigma != tau:
                div_edges.append(idx)

        # ----------------------------------------------------------------
        # candidate generation + best edit selection
        # ----------------------------------------------------------------
        total = len(triples)
        target_high = crec < lower  # need to raise WL-CREC?

        candidates = []
        if target_high and det_edges:
            rng.shuffle(det_edges)
            for idx in det_edges[:sample_size]:
                h, r, t = triples[idx]
                sig = (int(depth_col[h]), int(depth_col[t]))
                delta = delta_h_cond_remove(sig, r, sig_cnt, rel_cnt, total, crec, n_rel)
                if delta > 0:
                    candidates.append(("remove", idx, delta))
        elif not target_high and det_edges:
            rng.shuffle(det_edges)
            for idx in det_edges[:sample_size]:
                h, r, t = triples[idx]
                sig = (int(depth_col[h]), int(depth_col[t]))
                delta = delta_h_cond_add(sig, r, sig_cnt, rel_cnt, total, crec, n_rel)
                if delta < 0:
                    candidates.append(("add", (h, r, t), delta))

        # fall-back heuristics
        if not candidates:
            if target_high and div_edges:
                idx = rng.choice(div_edges)
                candidates.append(("remove", idx, 1e-9))
            elif not target_high:
                idx = rng.choice(det_edges) if det_edges else rng.randrange(total)
                h, r, t = triples[idx]
                candidates.append(("add", (h, r, t), -1e-9))

        if not candidates:  # no viable edit this round; resample next iteration
            continue

        best = (
            max(candidates, key=lambda x: x[2])
            if target_high
            else min(candidates, key=lambda x: x[2])
        )

        # apply edit
        if best[0] == "remove":
            triples.pop(best[1])
        else:
            triples.append(best[1])

        if (it + 1) % 1_000 == 0:
            logger.info("[iter %d] WL-CREC %.4f |T|=%d", it + 1, crec, len(triples))

    raise RuntimeError("Max iterations exceeded without hitting WL-CREC band.")
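
# ---------------------------------------------------------------------
# Usage sketch (illustrative: the random toy graph and the [0.20, 0.40]
# target band below are hypothetical, and running this requires the
# `...` delta-helper placeholders above to be filled in first, e.g.
# with the originals or the sketches).
# ---------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rng = random.Random(0)
    n_ent, n_rel = 500, 8
    toy = [
        (rng.randrange(n_ent), rng.randrange(n_rel), rng.randrange(n_ent))
        for _ in range(5_000)
    ]
    tuned = greedy_delta_tune_gpu(
        toy, n_ent, n_rel, depth=2, lower=0.20, upper=0.40, device=device
    )
    logger.info("tuned KG has %d triples", len(tuned))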