# from __future__ import annotations # import math # import os # import random # import threading # from collections import defaultdict # from concurrent.futures import ThreadPoolExecutor, as_completed # from typing import Dict, List, Tuple # import torch # from data.kg_dataset import KGDataset # from metrics.wlcrec import WLCREC # Assuming WLCREC is defined in # from tools import get_pretty_logger # logger = get_pretty_logger(__name__) # # -------------------------------------------------------------------- # # Edge-editing primitives # # -------------------------------------------------------------------- # # -- additions -------------------- # def add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng): # for _ in range(1000): # h = rng.randrange(n_ent) # t = rng.randrange(n_ent) # if colours[h] != colours[t]: # r = rng.randrange(n_rel) # triples.append((h, r, t)) # out_adj[h].append((r, t)) # in_adj[t].append((r, h)) # return # def add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng): # σ = rng.choice(colours) # τ = rng.choice(colours) # h = colours.index(σ) # t = colours.index(τ) # triples.append((h, 0, t)) # out_adj[h].append((0, t)) # in_adj[t].append((0, h)) # # -- removals -------------------- # def rem_div_edge(triples, out_adj, in_adj, colours, depth, rng): # cand = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]] # if cand: # idx = rng.choice(cand) # h, r, t = triples.pop(idx) # out_adj[h].remove((r, t)) # in_adj[t].remove((r, h)) # def rem_det_edge(triples, out_adj, in_adj, colours, depth, rng): # sig_rel = defaultdict(set) # for h, r, t in triples: # σ, τ = colours[h], colours[t] # sig_rel[(σ, τ)].add(r) # cand = [] # for i, (h, r, t) in enumerate(triples): # if len(sig_rel[(colours[h], colours[t])]) == 1: # cand.append(i) # if cand: # idx = rng.choice(cand) # h, r, t = triples.pop(idx) # out_adj[h].remove((r, t)) # in_adj[t].remove((r, h)) # def _search_worker( # seed: int, # triples_init: List[Tuple[int, int, int]], # triples_factory, # n_ent: int, # n_rel: int, # depth: int, # lo: float, # hi: float, # max_iters: int, # ) -> List[Tuple[int, int, int]] | None: # """ # Run the exact same hill‑climb that `tune_crec()` did, # but entirely in this process. Return the edited triples # once c falls in [lo, hi]; return None if we used up all iterations. # """ # rng = random.Random(seed) # triples = triples_init.copy().tolist() # for it in range(max_iters): # # WL‑CREC is *only* recomputed every 1000 edits, exactly like before # if it % 1000 == 0: # dataset = KGDataset( # triples_factory, # triples=torch.tensor(triples, dtype=torch.long), # num_entities=n_ent, # num_relations=n_rel, # ) # wl = WLCREC(dataset) # colours = wl.wl_colours(depth) # *_, c, _ = wl.compute(H=5) # unchanged API # if lo <= c <= hi: # success # logger.info("[seed %d] hit %.4f in %d edits", seed, c, it) # return triples # # ---------- identical edit logic ---------- # if c < lo: # if rng.random() < 0.5: # rem_det_edge(triples, colours, depth, rng) # else: # add_div_edge(triples, colours, depth, n_ent, n_rel, rng) # else: # c > hi # if rng.random() < 0.5: # rem_div_edge(triples, colours, depth, rng) # else: # add_det_edge(triples, colours, depth, n_rel, rng) # # ------------------------------------------ # return None # used up our budget # # -------------------------------------------------------------------- # # Unified tuner # # -------------------------------------------------------------------- # def tune_crec( # wl_crec: WLCREC, # n_ent: int, # n_rel: int, # depth: int, # lo: float, # hi: float, # max_iters: int = 80_000, # seed: int = 42, # ): # triples = wl_crec.dataset.triples.tolist() # rng = random.Random(seed) # for it in range(max_iters): # print(f"\r[iter {it + 1:5d}] ", end="") # if it % 1000 == 0: # dataset = KGDataset( # wl_crec.dataset.triples_factory, # triples=torch.tensor(triples, dtype=torch.long), # num_entities=n_ent, # num_relations=n_rel, # ) # tmp_wl_crec = WLCREC(dataset) # # colours = wl_colours(triples, n_ent, depth) # colours = tmp_wl_crec.wl_colours(depth) # _, _, _, _, c, _ = tmp_wl_crec.compute(H=5) # if lo <= c <= hi: # logging.info("WL-CREC %.4f reached after %d edits (|T|=%d)", c, it, len(triples)) # return triples # if c < lo: # # need ↑ WL-CREC → prefer deletion of deterministic, else add diversifying # if rng.random() < 0.5: # rem_det_edge(triples, colours, depth, rng) # else: # add_div_edge(triples, colours, depth, n_ent, n_rel, rng) # # need ↓ WL-CREC → prefer deletion of diversifying, else add deterministic # elif rng.random() < 0.5: # rem_div_edge(triples, colours, depth, rng) # else: # add_det_edge(triples, colours, depth, n_rel, rng) # if (it + 1) % 10000 == 1: # logging.info("[iter %d] WL-CREC %.4f |T|=%d", it + 1, c, len(triples)) # raise RuntimeError("Exceeded max iterations without hitting target band.") # def _edit_batch( # worker_id: int, # n_edits: int, # colours: List[int], # crec: float, # target_lo: float, # target_hi: float, # depth: int, # n_ent: int, # n_rel: int, # seed_base: int, # ): # """ # Perform `n_edits` topology modifications *locally* and return # (added_triples, removed_triples) lists. # """ # rng = random.Random(seed_base + worker_id) # local_added, local_removed = [], [] # # local graph views: only sizes matter for edit selection # for _ in range(n_edits): # if crec < target_lo: # need ↑ WL-CREC # if rng.random() < 0.5: # • remove deterministic # # We cannot remove by index safely without the global list; # # choose a *signature* and remember intention to delete: # local_removed.append(("det", rng.random())) # else: # • add diversifying # h = rng.randrange(n_ent) # t = rng.randrange(n_ent) # while colours[h] == colours[t]: # h = rng.randrange(n_ent) # t = rng.randrange(n_ent) # r = rng.randrange(n_rel) # local_added.append((h, r, t)) # else: # need ↓ WL-CREC # if rng.random() < 0.5: # • remove diversifying # local_removed.append(("div", rng.random())) # else: # • add deterministic # σ = rng.choice(colours) # τ = rng.choice(colours) # h = colours.index(σ) # t = colours.index(τ) # local_added.append((h, 0, t)) # return local_added, local_removed # def wl_colours(triples, n_ent, depth): # # build adjacency once # out_adj, in_adj = defaultdict(list), defaultdict(list) # for h, r, t in triples: # out_adj[h].append((r, t)) # in_adj[t].append((r, h)) # colours_rounds = [[0] * n_ent] # round-0 colours # for h in range(1, depth + 1): # prev = colours_rounds[-1] # # 1) build textual signatures in parallel # def sig(v): # neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [ # ("↑", r, prev[u]) for r, u in in_adj.get(v, []) # ] # neigh.sort() # return (prev[v], tuple(neigh)) # with ThreadPoolExecutor() as tpe: # cheap threads inside worker # sigs = list(tpe.map(sig, range(n_ent))) # # 2) assign deterministic colour IDs # sig2id: Dict[Tuple, int] = {} # next_round = [0] * n_ent # fresh = 0 # for v, sg in enumerate(sigs): # cid = sig2id.setdefault(sg, fresh) # if cid == fresh: # fresh += 1 # next_round[v] = cid # colours_rounds.append(next_round) # depth_colours = colours_rounds[-1] # return depth_colours # def _metric_worker(args): # triples, n_ent, n_rel, depth = args # dataset = KGDataset( # triples_factory=None, # Not used in this context # triples=torch.tensor(triples, dtype=torch.long), # num_entities=n_ent, # num_relations=n_rel, # ) # wl_crec = WLCREC(dataset) # _, _, _, _, c, _ = wl_crec.compute(H=5, return_full=False) # colours = wl_colours(triples, n_ent, depth) # return c, colours # def tune_crec_parallel_edits( # triples_init, # n_ent: int, # n_rel: int, # depth: int, # target_lo: float, # target_hi: float, # max_iters: int = 80_000, # metric_every: int = 100, # n_workers: int = max(20, math.ceil(os.cpu_count() / 2)), # seed: int = 42, # ): # # -------- shared mutable state (main thread owns it) -------- # triples = triples_init.tolist() # out_adj, in_adj = defaultdict(list), defaultdict(list) # for h, r, t in triples: # out_adj[h].append((r, t)) # in_adj[t].append((r, h)) # pool = ThreadPoolExecutor(max_workers=n_workers) # metric_lock = threading.Lock() # exactly one metric at a time # rng_global = random.Random(seed) # # ----- first metric checkpoint ----- # crec, colours = _metric_worker((triples, n_ent, n_rel, depth)) # edit_budget_total = 0 # for it in range(0, max_iters, metric_every): # # ========================================================= # # 1. PARALLEL EDIT STAGE (metric_every edits in total) # # ========================================================= # futures = [] # edits_per_worker = metric_every // n_workers # extra = metric_every % n_workers # for wid in range(n_workers): # n_edits = edits_per_worker + (1 if wid < extra else 0) # futures.append( # pool.submit( # _edit_batch, # wid, # n_edits, # colours, # crec, # target_lo, # target_hi, # depth, # n_ent, # n_rel, # seed, # ) # ) # # merge when workers finish # for fut in as_completed(futures): # added, removed_specs = fut.result() # # --- apply additions immediately (cheap, conflict-free) # for h, r, t in added: # triples.append((h, r, t)) # out_adj[h].append((r, t)) # in_adj[t].append((r, h)) # # --- apply removals: interpret the spec on *current* graph # for typ, randv in removed_specs: # if typ == "div": # idxs = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]] # else: # 'det' # sig_rel = defaultdict(set) # for h, r, t in triples: # sig_rel[(colours[h], colours[t])].add(r) # idxs = [ # i # for i, (h, r, t) in enumerate(triples) # if len(sig_rel[(colours[h], colours[t])]) == 1 # ] # if idxs: # victim = idxs[int(randv * len(idxs))] # h, r, t = triples.pop(victim) # out_adj[h].remove((r, t)) # in_adj[t].remove((r, h)) # edit_budget_total += metric_every # # ========================================================= # # 2. SINGLE-THREADED METRIC CHECKPOINT (synchronised) # # ========================================================= # # if edit_budget_total % (10 * metric_every) == 0: # if True: # with metric_lock: # crec, colours = _metric_worker((triples, n_ent, n_rel, depth)) # logging.info( # "After %d edits WL-CREC = %.4f |T|=%d", edit_budget_total, crec, len(triples) # ) # if target_lo <= crec <= target_hi: # logging.info("Target band reached.") # pool.shutdown(wait=True) # return triples # pool.shutdown(wait=True) # raise RuntimeError("Exceeded max iteration budget without success.") # from __future__ import annotations # import logging # import random # from collections import defaultdict # from concurrent.futures import ThreadPoolExecutor, as_completed # from typing import List, Tuple # import torch # from data.kg_dataset import KGDataset # from metrics.wlcrec import WLCREC # from tools import get_pretty_logger # logger = get_pretty_logger(__name__) # # ------------ additions ------------ # def propose_add_div(triples: List, colours, n_ent: int, n_rel: int, rng: random.Random): # for _ in range(1000): # h, t = rng.randrange(n_ent), rng.randrange(n_ent) # if colours[h] != colours[t]: # r = rng.randrange(n_rel) # return ("add", (h, r, t)) # return None # fell through – extremely rare # def propose_add_det(colours, n_rel: int, rng: random.Random): # σ, τ = rng.choice(colours), rng.choice(colours) # h, t = colours.index(σ), colours.index(τ) # return ("add", (h, 0, t)) # rel 0 = deterministic # # ------------ removals ------------ # def propose_rem_div(triples: List, colours, rng: random.Random): # cand = [trp for trp in triples if colours[trp[0]] != colours[trp[2]]] # return ("rem", rng.choice(cand)) if cand else None # def propose_rem_det(triples: List, colours, rng: random.Random): # sig_rel = defaultdict(set) # for h, r, t in triples: # sig_rel[(colours[h], colours[t])].add(r) # cand = [trp for trp in triples if len(sig_rel[(colours[trp[0]], colours[trp[2]])]) == 1] # return ("rem", rng.choice(cand)) if cand else None # def make_edit_proposal( # triples_snapshot: List, # colours_snapshot, # c: float, # lo: float, # hi: float, # n_ent: int, # n_rel: int, # depth: int, # still here for future use / signatures # seed: int, # ): # """Return exactly one ('add' | 'rem', triple) proposal or None.""" # rng = random.Random(seed) # # -- decide which kind of edit we want, *given* the current c ------------ # if c < lo: # ↑ WL‑CREC (delete deterministic ∨ add diversifying) # chooser = (propose_rem_det, propose_add_div) # else: # ↓ WL‑CREC (delete diversifying ∨ add deterministic) # chooser = (propose_rem_div, propose_add_det) # op = rng.choice(chooser) # return ( # op(triples_snapshot, colours_snapshot, n_ent, n_rel, rng) # if op.__name__.startswith("propose_add") # else op(triples_snapshot, colours_snapshot, rng) # ) # def tune_crec_parallel_edits( # triples: List, # n_ent: int, # n_rel: int, # depth: int, # lo: float, # hi: float, # *, # max_iters: int = 80_000, # edits_per_eval: int = 1000, # == old “if it % 1000 == 0” # batch_size: int = 256, # how many proposals we farm out at once # max_workers: int = 4, # seed: int = 42, # ) -> List: # rng_global = random.Random(seed) # triples = set(triples) # deduplicate if needed # with ThreadPoolExecutor(max_workers=max_workers) as pool: # proposal_seed = seed * 997 # deterministic but different stream # # edit_counter = 0 # # while edit_counter < max_iters: # # # ----------------- expensive part (single‑thread) ---------------- # # dataset = KGDataset( # # triples_factory=None, # Not used in this context # # triples=torch.tensor(triples, dtype=torch.long), # # num_entities=n_ent, # # num_relations=n_rel, # # ) # # tmp = WLCREC(dataset) # # colours = tmp.wl_colours(depth) # # *_, c, _ = tmp.compute(H=5) # # # ----------------------------------------------------------------- # # if lo <= c <= hi: # # logger.info( # # "WL‑CREC %.4f reached after %d edits |T|=%d", c, edit_counter, len(triples) # # ) # # return triples # # ============ parallel block: just make `edits_per_eval` proposals # needed = min(edits_per_eval, max_iters - edit_counter) # proposals = [] # while len(proposals) < needed: # # launch a batch of workers # futs = [ # pool.submit( # make_edit_proposal, # triples, # colours, # c, # lo, # hi, # n_ent, # n_rel, # depth, # proposal_seed + i, # ) # for i in range(batch_size) # ] # for f in as_completed(futs): # prop = f.result() # if prop is not None: # proposals.append(prop) # if len(proposals) == needed: # break # proposal_seed += batch_size # move RNG window forward # # -------------- apply the gathered proposals *sequentially* ------- # for kind, trp in proposals: # if kind == "add": # triples.append(trp) # else: # "rem" # try: # triples.remove(trp) # except ValueError: # pass # already gone – benign collision # # ----------------------------------------------------------------- # edit_counter += needed # if edit_counter % 1_000 == 0: # logger.info("[iter %d] c=%.4f |T|=%d", edit_counter, c, len(triples)) # raise RuntimeError("Exceeded max_iters without hitting target band.") import os import random import threading from collections import defaultdict from typing import Dict, List, Tuple # -------------------------------------------------------------------- # Logging helper (replace with your own if you prefer) -------------- # -------------------------------------------------------------------- from tools import get_pretty_logger logging = get_pretty_logger(__name__) # -------------------------------------------------------------------- # Original edge‑editing primitives (unchanged) ---------------------- # -------------------------------------------------------------------- def add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng): for _ in range(1000): h = rng.randrange(n_ent) t = rng.randrange(n_ent) if colours[h] != colours[t]: r = rng.randrange(n_rel) triples.append((h, r, t)) out_adj[h].append((r, t)) in_adj[t].append((r, h)) return def add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng): σ = rng.choice(colours) τ = rng.choice(colours) h = colours.index(σ) t = colours.index(τ) triples.append((h, 0, t)) out_adj[h].append((0, t)) in_adj[t].append((0, h)) def rem_div_edge(triples, out_adj, in_adj, colours, depth, rng): cand = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]] if cand: idx = rng.choice(cand) h, r, t = triples.pop(idx) out_adj[h].remove((r, t)) in_adj[t].remove((r, h)) def rem_det_edge(triples, out_adj, in_adj, colours, depth, rng): sig_rel = defaultdict(set) for h, r, t in triples: σ, τ = colours[h], colours[t] sig_rel[(σ, τ)].add(r) cand = [] for i, (h, r, t) in enumerate(triples): if len(sig_rel[(colours[h], colours[t])]) == 1: cand.append(i) if cand: idx = rng.choice(cand) h, r, t = triples.pop(idx) out_adj[h].remove((r, t)) in_adj[t].remove((r, h)) def _worker( *, worker_id: int, rng: random.Random, triples: List[Tuple[int, int, int]], out_adj: Dict[int, List[Tuple[int, int]]], in_adj: Dict[int, List[Tuple[int, int]]], colours: List[int], n_ent: int, n_rel: int, depth: int, c: float, lo: float, hi: float, max_iters: int, state_lock: threading.Lock, stop_event: threading.Event, ): """One thread: mutate the *shared* structures until success/stop.""" for it in range(max_iters): if stop_event.is_set(): return # someone else finished ─ exit early with state_lock: # protect the shared graph # if lo <= c <= hi: # logging.info( # "[worker %d] converged after %d steps (CREC %.4f, |T|=%d)", # worker_id, # it, # c, # len(triples), # ) # stop_event.set() # return # Choose and apply one edit ----------------------------- if c < lo: # need ↑ CREC if rng.random() < 0.5: rem_det_edge(triples, out_adj, in_adj, colours, depth, rng) else: add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng) elif rng.random() < 0.5: rem_div_edge(triples, out_adj, in_adj, colours, depth, rng) else: add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng) logging.warning("[worker %d] reached max_iters", worker_id) stop_event.set() return # -------------------------------------------------------------------- # Public API -------------------------------------------------------- # -------------------------------------------------------------------- def fast_tune_crec( triples: List, colours: List, n_ent: int, n_rel: int, depth: int, c: float, lo: float, hi: float, max_iters: int = 1000, max_workers: int | None = None, seeds: List[int] | None = None, ) -> List[Tuple[int, int, int]]: """Tune WL‑CREC with *shared* triples using multiple threads. Returns the **same list instance** that was passed in – already modified in place by the winning thread. """ if max_workers is None: # max_workers = os.cpu_count() or 4 max_workers = 4 if seeds is None: seeds = [42 + i for i in range(max_workers)] assert len(seeds) >= max_workers, "Need at least one seed per worker" # Prepare adjacency once (shared) -------------------------------- out_adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list) in_adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list) for h, r, t in triples: out_adj[h].append((r, t)) in_adj[t].append((r, h)) state_lock = threading.Lock() stop_event = threading.Event() logging.info( "Launching %d threads on shared triples (target %.3f–%.3f)", max_workers, lo, hi, ) threads: List[threading.Thread] = [] for wid in range(max_workers): t = threading.Thread( name=f"tune‑crec‑worker‑{wid}", target=_worker, kwargs=dict( worker_id=wid, rng=random.Random(seeds[wid]), triples=triples, out_adj=out_adj, in_adj=in_adj, colours=colours, n_ent=n_ent, n_rel=n_rel, depth=depth, c=c, lo=lo, hi=hi, max_iters=max_iters, state_lock=state_lock, stop_event=stop_event, ), daemon=False, ) threads.append(t) t.start() for t in threads: t.join() if not stop_event.is_set(): raise RuntimeError("No thread converged – try increasing max_iters or widen band") return triples