You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

crec_modifier.py 25KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732
  1. # from __future__ import annotations
  2. # import math
  3. # import os
  4. # import random
  5. # import threading
  6. # from collections import defaultdict
  7. # from concurrent.futures import ThreadPoolExecutor, as_completed
  8. # from typing import Dict, List, Tuple
  9. # import torch
  10. # from data.kg_dataset import KGDataset
  11. # from metrics.wlcrec import WLCREC # Assuming WLCREC is defined in
  12. # from tools import get_pretty_logger
  13. # logger = get_pretty_logger(__name__)
  14. # # --------------------------------------------------------------------
  15. # # Edge-editing primitives
  16. # # --------------------------------------------------------------------
  17. # # -- additions --------------------
  18. # def add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng):
  19. # for _ in range(1000):
  20. # h = rng.randrange(n_ent)
  21. # t = rng.randrange(n_ent)
  22. # if colours[h] != colours[t]:
  23. # r = rng.randrange(n_rel)
  24. # triples.append((h, r, t))
  25. # out_adj[h].append((r, t))
  26. # in_adj[t].append((r, h))
  27. # return
  28. # def add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng):
  29. # σ = rng.choice(colours)
  30. # τ = rng.choice(colours)
  31. # h = colours.index(σ)
  32. # t = colours.index(τ)
  33. # triples.append((h, 0, t))
  34. # out_adj[h].append((0, t))
  35. # in_adj[t].append((0, h))
  36. # # -- removals --------------------
  37. # def rem_div_edge(triples, out_adj, in_adj, colours, depth, rng):
  38. # cand = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]]
  39. # if cand:
  40. # idx = rng.choice(cand)
  41. # h, r, t = triples.pop(idx)
  42. # out_adj[h].remove((r, t))
  43. # in_adj[t].remove((r, h))
  44. # def rem_det_edge(triples, out_adj, in_adj, colours, depth, rng):
  45. # sig_rel = defaultdict(set)
  46. # for h, r, t in triples:
  47. # σ, τ = colours[h], colours[t]
  48. # sig_rel[(σ, τ)].add(r)
  49. # cand = []
  50. # for i, (h, r, t) in enumerate(triples):
  51. # if len(sig_rel[(colours[h], colours[t])]) == 1:
  52. # cand.append(i)
  53. # if cand:
  54. # idx = rng.choice(cand)
  55. # h, r, t = triples.pop(idx)
  56. # out_adj[h].remove((r, t))
  57. # in_adj[t].remove((r, h))
  58. # def _search_worker(
  59. # seed: int,
  60. # triples_init: List[Tuple[int, int, int]],
  61. # triples_factory,
  62. # n_ent: int,
  63. # n_rel: int,
  64. # depth: int,
  65. # lo: float,
  66. # hi: float,
  67. # max_iters: int,
  68. # ) -> List[Tuple[int, int, int]] | None:
  69. # """
  70. # Run the exact same hill‑climb that `tune_crec()` did,
  71. # but entirely in this process. Return the edited triples
  72. # once c falls in [lo, hi]; return None if we used up all iterations.
  73. # """
  74. # rng = random.Random(seed)
  75. # triples = triples_init.copy().tolist()
  76. # for it in range(max_iters):
  77. # # WL‑CREC is *only* recomputed every 1000 edits, exactly like before
  78. # if it % 1000 == 0:
  79. # dataset = KGDataset(
  80. # triples_factory,
  81. # triples=torch.tensor(triples, dtype=torch.long),
  82. # num_entities=n_ent,
  83. # num_relations=n_rel,
  84. # )
  85. # wl = WLCREC(dataset)
  86. # colours = wl.wl_colours(depth)
  87. # *_, c, _ = wl.compute(H=5) # unchanged API
  88. # if lo <= c <= hi: # success
  89. # logger.info("[seed %d] hit %.4f in %d edits", seed, c, it)
  90. # return triples
  91. # # ---------- identical edit logic ----------
  92. # if c < lo:
  93. # if rng.random() < 0.5:
  94. # rem_det_edge(triples, colours, depth, rng)
  95. # else:
  96. # add_div_edge(triples, colours, depth, n_ent, n_rel, rng)
  97. # else: # c > hi
  98. # if rng.random() < 0.5:
  99. # rem_div_edge(triples, colours, depth, rng)
  100. # else:
  101. # add_det_edge(triples, colours, depth, n_rel, rng)
  102. # # ------------------------------------------
  103. # return None # used up our budget
  104. # # --------------------------------------------------------------------
  105. # # Unified tuner
  106. # # --------------------------------------------------------------------
  107. # def tune_crec(
  108. # wl_crec: WLCREC,
  109. # n_ent: int,
  110. # n_rel: int,
  111. # depth: int,
  112. # lo: float,
  113. # hi: float,
  114. # max_iters: int = 80_000,
  115. # seed: int = 42,
  116. # ):
  117. # triples = wl_crec.dataset.triples.tolist()
  118. # rng = random.Random(seed)
  119. # for it in range(max_iters):
  120. # print(f"\r[iter {it + 1:5d}] ", end="")
  121. # if it % 1000 == 0:
  122. # dataset = KGDataset(
  123. # wl_crec.dataset.triples_factory,
  124. # triples=torch.tensor(triples, dtype=torch.long),
  125. # num_entities=n_ent,
  126. # num_relations=n_rel,
  127. # )
  128. # tmp_wl_crec = WLCREC(dataset)
  129. # # colours = wl_colours(triples, n_ent, depth)
  130. # colours = tmp_wl_crec.wl_colours(depth)
  131. # _, _, _, _, c, _ = tmp_wl_crec.compute(H=5)
  132. # if lo <= c <= hi:
  133. # logging.info("WL-CREC %.4f reached after %d edits (|T|=%d)", c, it, len(triples))
  134. # return triples
  135. # if c < lo:
  136. # # need ↑ WL-CREC → prefer deletion of deterministic, else add diversifying
  137. # if rng.random() < 0.5:
  138. # rem_det_edge(triples, colours, depth, rng)
  139. # else:
  140. # add_div_edge(triples, colours, depth, n_ent, n_rel, rng)
  141. # # need ↓ WL-CREC → prefer deletion of diversifying, else add deterministic
  142. # elif rng.random() < 0.5:
  143. # rem_div_edge(triples, colours, depth, rng)
  144. # else:
  145. # add_det_edge(triples, colours, depth, n_rel, rng)
  146. # if (it + 1) % 10000 == 1:
  147. # logging.info("[iter %d] WL-CREC %.4f |T|=%d", it + 1, c, len(triples))
  148. # raise RuntimeError("Exceeded max iterations without hitting target band.")
  149. # def _edit_batch(
  150. # worker_id: int,
  151. # n_edits: int,
  152. # colours: List[int],
  153. # crec: float,
  154. # target_lo: float,
  155. # target_hi: float,
  156. # depth: int,
  157. # n_ent: int,
  158. # n_rel: int,
  159. # seed_base: int,
  160. # ):
  161. # """
  162. # Perform `n_edits` topology modifications *locally* and return
  163. # (added_triples, removed_triples) lists.
  164. # """
  165. # rng = random.Random(seed_base + worker_id)
  166. # local_added, local_removed = [], []
  167. # # local graph views: only sizes matter for edit selection
  168. # for _ in range(n_edits):
  169. # if crec < target_lo: # need ↑ WL-CREC
  170. # if rng.random() < 0.5: # • remove deterministic
  171. # # We cannot remove by index safely without the global list;
  172. # # choose a *signature* and remember intention to delete:
  173. # local_removed.append(("det", rng.random()))
  174. # else: # • add diversifying
  175. # h = rng.randrange(n_ent)
  176. # t = rng.randrange(n_ent)
  177. # while colours[h] == colours[t]:
  178. # h = rng.randrange(n_ent)
  179. # t = rng.randrange(n_ent)
  180. # r = rng.randrange(n_rel)
  181. # local_added.append((h, r, t))
  182. # else: # need ↓ WL-CREC
  183. # if rng.random() < 0.5: # • remove diversifying
  184. # local_removed.append(("div", rng.random()))
  185. # else: # • add deterministic
  186. # σ = rng.choice(colours)
  187. # τ = rng.choice(colours)
  188. # h = colours.index(σ)
  189. # t = colours.index(τ)
  190. # local_added.append((h, 0, t))
  191. # return local_added, local_removed
  192. # def wl_colours(triples, n_ent, depth):
  193. # # build adjacency once
  194. # out_adj, in_adj = defaultdict(list), defaultdict(list)
  195. # for h, r, t in triples:
  196. # out_adj[h].append((r, t))
  197. # in_adj[t].append((r, h))
  198. # colours_rounds = [[0] * n_ent] # round-0 colours
  199. # for h in range(1, depth + 1):
  200. # prev = colours_rounds[-1]
  201. # # 1) build textual signatures in parallel
  202. # def sig(v):
  203. # neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [
  204. # ("↑", r, prev[u]) for r, u in in_adj.get(v, [])
  205. # ]
  206. # neigh.sort()
  207. # return (prev[v], tuple(neigh))
  208. # with ThreadPoolExecutor() as tpe: # cheap threads inside worker
  209. # sigs = list(tpe.map(sig, range(n_ent)))
  210. # # 2) assign deterministic colour IDs
  211. # sig2id: Dict[Tuple, int] = {}
  212. # next_round = [0] * n_ent
  213. # fresh = 0
  214. # for v, sg in enumerate(sigs):
  215. # cid = sig2id.setdefault(sg, fresh)
  216. # if cid == fresh:
  217. # fresh += 1
  218. # next_round[v] = cid
  219. # colours_rounds.append(next_round)
  220. # depth_colours = colours_rounds[-1]
  221. # return depth_colours
  222. # def _metric_worker(args):
  223. # triples, n_ent, n_rel, depth = args
  224. # dataset = KGDataset(
  225. # triples_factory=None, # Not used in this context
  226. # triples=torch.tensor(triples, dtype=torch.long),
  227. # num_entities=n_ent,
  228. # num_relations=n_rel,
  229. # )
  230. # wl_crec = WLCREC(dataset)
  231. # _, _, _, _, c, _ = wl_crec.compute(H=5, return_full=False)
  232. # colours = wl_colours(triples, n_ent, depth)
  233. # return c, colours
  234. # def tune_crec_parallel_edits(
  235. # triples_init,
  236. # n_ent: int,
  237. # n_rel: int,
  238. # depth: int,
  239. # target_lo: float,
  240. # target_hi: float,
  241. # max_iters: int = 80_000,
  242. # metric_every: int = 100,
  243. # n_workers: int = max(20, math.ceil(os.cpu_count() / 2)),
  244. # seed: int = 42,
  245. # ):
  246. # # -------- shared mutable state (main thread owns it) --------
  247. # triples = triples_init.tolist()
  248. # out_adj, in_adj = defaultdict(list), defaultdict(list)
  249. # for h, r, t in triples:
  250. # out_adj[h].append((r, t))
  251. # in_adj[t].append((r, h))
  252. # pool = ThreadPoolExecutor(max_workers=n_workers)
  253. # metric_lock = threading.Lock() # exactly one metric at a time
  254. # rng_global = random.Random(seed)
  255. # # ----- first metric checkpoint -----
  256. # crec, colours = _metric_worker((triples, n_ent, n_rel, depth))
  257. # edit_budget_total = 0
  258. # for it in range(0, max_iters, metric_every):
  259. # # =========================================================
  260. # # 1. PARALLEL EDIT STAGE (metric_every edits in total)
  261. # # =========================================================
  262. # futures = []
  263. # edits_per_worker = metric_every // n_workers
  264. # extra = metric_every % n_workers
  265. # for wid in range(n_workers):
  266. # n_edits = edits_per_worker + (1 if wid < extra else 0)
  267. # futures.append(
  268. # pool.submit(
  269. # _edit_batch,
  270. # wid,
  271. # n_edits,
  272. # colours,
  273. # crec,
  274. # target_lo,
  275. # target_hi,
  276. # depth,
  277. # n_ent,
  278. # n_rel,
  279. # seed,
  280. # )
  281. # )
  282. # # merge when workers finish
  283. # for fut in as_completed(futures):
  284. # added, removed_specs = fut.result()
  285. # # --- apply additions immediately (cheap, conflict-free)
  286. # for h, r, t in added:
  287. # triples.append((h, r, t))
  288. # out_adj[h].append((r, t))
  289. # in_adj[t].append((r, h))
  290. # # --- apply removals: interpret the spec on *current* graph
  291. # for typ, randv in removed_specs:
  292. # if typ == "div":
  293. # idxs = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]]
  294. # else: # 'det'
  295. # sig_rel = defaultdict(set)
  296. # for h, r, t in triples:
  297. # sig_rel[(colours[h], colours[t])].add(r)
  298. # idxs = [
  299. # i
  300. # for i, (h, r, t) in enumerate(triples)
  301. # if len(sig_rel[(colours[h], colours[t])]) == 1
  302. # ]
  303. # if idxs:
  304. # victim = idxs[int(randv * len(idxs))]
  305. # h, r, t = triples.pop(victim)
  306. # out_adj[h].remove((r, t))
  307. # in_adj[t].remove((r, h))
  308. # edit_budget_total += metric_every
  309. # # =========================================================
  310. # # 2. SINGLE-THREADED METRIC CHECKPOINT (synchronised)
  311. # # =========================================================
  312. # # if edit_budget_total % (10 * metric_every) == 0:
  313. # if True:
  314. # with metric_lock:
  315. # crec, colours = _metric_worker((triples, n_ent, n_rel, depth))
  316. # logging.info(
  317. # "After %d edits WL-CREC = %.4f |T|=%d", edit_budget_total, crec, len(triples)
  318. # )
  319. # if target_lo <= crec <= target_hi:
  320. # logging.info("Target band reached.")
  321. # pool.shutdown(wait=True)
  322. # return triples
  323. # pool.shutdown(wait=True)
  324. # raise RuntimeError("Exceeded max iteration budget without success.")
  325. # from __future__ import annotations
  326. # import logging
  327. # import random
  328. # from collections import defaultdict
  329. # from concurrent.futures import ThreadPoolExecutor, as_completed
  330. # from typing import List, Tuple
  331. # import torch
  332. # from data.kg_dataset import KGDataset
  333. # from metrics.wlcrec import WLCREC
  334. # from tools import get_pretty_logger
  335. # logger = get_pretty_logger(__name__)
  336. # # ------------ additions ------------
  337. # def propose_add_div(triples: List, colours, n_ent: int, n_rel: int, rng: random.Random):
  338. # for _ in range(1000):
  339. # h, t = rng.randrange(n_ent), rng.randrange(n_ent)
  340. # if colours[h] != colours[t]:
  341. # r = rng.randrange(n_rel)
  342. # return ("add", (h, r, t))
  343. # return None # fell through – extremely rare
  344. # def propose_add_det(colours, n_rel: int, rng: random.Random):
  345. # σ, τ = rng.choice(colours), rng.choice(colours)
  346. # h, t = colours.index(σ), colours.index(τ)
  347. # return ("add", (h, 0, t)) # rel 0 = deterministic
  348. # # ------------ removals ------------
  349. # def propose_rem_div(triples: List, colours, rng: random.Random):
  350. # cand = [trp for trp in triples if colours[trp[0]] != colours[trp[2]]]
  351. # return ("rem", rng.choice(cand)) if cand else None
  352. # def propose_rem_det(triples: List, colours, rng: random.Random):
  353. # sig_rel = defaultdict(set)
  354. # for h, r, t in triples:
  355. # sig_rel[(colours[h], colours[t])].add(r)
  356. # cand = [trp for trp in triples if len(sig_rel[(colours[trp[0]], colours[trp[2]])]) == 1]
  357. # return ("rem", rng.choice(cand)) if cand else None
  358. # def make_edit_proposal(
  359. # triples_snapshot: List,
  360. # colours_snapshot,
  361. # c: float,
  362. # lo: float,
  363. # hi: float,
  364. # n_ent: int,
  365. # n_rel: int,
  366. # depth: int, # still here for future use / signatures
  367. # seed: int,
  368. # ):
  369. # """Return exactly one ('add' | 'rem', triple) proposal or None."""
  370. # rng = random.Random(seed)
  371. # # -- decide which kind of edit we want, *given* the current c ------------
  372. # if c < lo: # ↑ WL‑CREC (delete deterministic ∨ add diversifying)
  373. # chooser = (propose_rem_det, propose_add_div)
  374. # else: # ↓ WL‑CREC (delete diversifying ∨ add deterministic)
  375. # chooser = (propose_rem_div, propose_add_det)
  376. # op = rng.choice(chooser)
  377. # return (
  378. # op(triples_snapshot, colours_snapshot, n_ent, n_rel, rng)
  379. # if op.__name__.startswith("propose_add")
  380. # else op(triples_snapshot, colours_snapshot, rng)
  381. # )
  382. # def tune_crec_parallel_edits(
  383. # triples: List,
  384. # n_ent: int,
  385. # n_rel: int,
  386. # depth: int,
  387. # lo: float,
  388. # hi: float,
  389. # *,
  390. # max_iters: int = 80_000,
  391. # edits_per_eval: int = 1000, # == old “if it % 1000 == 0”
  392. # batch_size: int = 256, # how many proposals we farm out at once
  393. # max_workers: int = 4,
  394. # seed: int = 42,
  395. # ) -> List:
  396. # rng_global = random.Random(seed)
  397. # triples = set(triples) # deduplicate if needed
  398. # with ThreadPoolExecutor(max_workers=max_workers) as pool:
  399. # proposal_seed = seed * 997 # deterministic but different stream
  400. # # edit_counter = 0
  401. # # while edit_counter < max_iters:
  402. # # # ----------------- expensive part (single‑thread) ----------------
  403. # # dataset = KGDataset(
  404. # # triples_factory=None, # Not used in this context
  405. # # triples=torch.tensor(triples, dtype=torch.long),
  406. # # num_entities=n_ent,
  407. # # num_relations=n_rel,
  408. # # )
  409. # # tmp = WLCREC(dataset)
  410. # # colours = tmp.wl_colours(depth)
  411. # # *_, c, _ = tmp.compute(H=5)
  412. # # # -----------------------------------------------------------------
  413. # # if lo <= c <= hi:
  414. # # logger.info(
  415. # # "WL‑CREC %.4f reached after %d edits |T|=%d", c, edit_counter, len(triples)
  416. # # )
  417. # # return triples
  418. # # ============ parallel block: just make `edits_per_eval` proposals
  419. # needed = min(edits_per_eval, max_iters - edit_counter)
  420. # proposals = []
  421. # while len(proposals) < needed:
  422. # # launch a batch of workers
  423. # futs = [
  424. # pool.submit(
  425. # make_edit_proposal,
  426. # triples,
  427. # colours,
  428. # c,
  429. # lo,
  430. # hi,
  431. # n_ent,
  432. # n_rel,
  433. # depth,
  434. # proposal_seed + i,
  435. # )
  436. # for i in range(batch_size)
  437. # ]
  438. # for f in as_completed(futs):
  439. # prop = f.result()
  440. # if prop is not None:
  441. # proposals.append(prop)
  442. # if len(proposals) == needed:
  443. # break
  444. # proposal_seed += batch_size # move RNG window forward
  445. # # -------------- apply the gathered proposals *sequentially* -------
  446. # for kind, trp in proposals:
  447. # if kind == "add":
  448. # triples.append(trp)
  449. # else: # "rem"
  450. # try:
  451. # triples.remove(trp)
  452. # except ValueError:
  453. # pass # already gone – benign collision
  454. # # -----------------------------------------------------------------
  455. # edit_counter += needed
  456. # if edit_counter % 1_000 == 0:
  457. # logger.info("[iter %d] c=%.4f |T|=%d", edit_counter, c, len(triples))
  458. # raise RuntimeError("Exceeded max_iters without hitting target band.")
  459. import os
  460. import random
  461. import threading
  462. from collections import defaultdict
  463. from typing import Dict, List, Tuple
  464. # --------------------------------------------------------------------
  465. # Logging helper (replace with your own if you prefer) --------------
  466. # --------------------------------------------------------------------
  467. from tools import get_pretty_logger
  468. logging = get_pretty_logger(__name__)
  469. # --------------------------------------------------------------------
  470. # Original edge‑editing primitives (unchanged) ----------------------
  471. # --------------------------------------------------------------------
  472. def add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng):
  473. for _ in range(1000):
  474. h = rng.randrange(n_ent)
  475. t = rng.randrange(n_ent)
  476. if colours[h] != colours[t]:
  477. r = rng.randrange(n_rel)
  478. triples.append((h, r, t))
  479. out_adj[h].append((r, t))
  480. in_adj[t].append((r, h))
  481. return
  482. def add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng):
  483. σ = rng.choice(colours)
  484. τ = rng.choice(colours)
  485. h = colours.index(σ)
  486. t = colours.index(τ)
  487. triples.append((h, 0, t))
  488. out_adj[h].append((0, t))
  489. in_adj[t].append((0, h))
  490. def rem_div_edge(triples, out_adj, in_adj, colours, depth, rng):
  491. cand = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]]
  492. if cand:
  493. idx = rng.choice(cand)
  494. h, r, t = triples.pop(idx)
  495. out_adj[h].remove((r, t))
  496. in_adj[t].remove((r, h))
  497. def rem_det_edge(triples, out_adj, in_adj, colours, depth, rng):
  498. sig_rel = defaultdict(set)
  499. for h, r, t in triples:
  500. σ, τ = colours[h], colours[t]
  501. sig_rel[(σ, τ)].add(r)
  502. cand = []
  503. for i, (h, r, t) in enumerate(triples):
  504. if len(sig_rel[(colours[h], colours[t])]) == 1:
  505. cand.append(i)
  506. if cand:
  507. idx = rng.choice(cand)
  508. h, r, t = triples.pop(idx)
  509. out_adj[h].remove((r, t))
  510. in_adj[t].remove((r, h))
  511. def _worker(
  512. *,
  513. worker_id: int,
  514. rng: random.Random,
  515. triples: List[Tuple[int, int, int]],
  516. out_adj: Dict[int, List[Tuple[int, int]]],
  517. in_adj: Dict[int, List[Tuple[int, int]]],
  518. colours: List[int],
  519. n_ent: int,
  520. n_rel: int,
  521. depth: int,
  522. c: float,
  523. lo: float,
  524. hi: float,
  525. max_iters: int,
  526. state_lock: threading.Lock,
  527. stop_event: threading.Event,
  528. ):
  529. """One thread: mutate the *shared* structures until success/stop."""
  530. for it in range(max_iters):
  531. if stop_event.is_set():
  532. return # someone else finished ─ exit early
  533. with state_lock: # protect the shared graph
  534. # if lo <= c <= hi:
  535. # logging.info(
  536. # "[worker %d] converged after %d steps (CREC %.4f, |T|=%d)",
  537. # worker_id,
  538. # it,
  539. # c,
  540. # len(triples),
  541. # )
  542. # stop_event.set()
  543. # return
  544. # Choose and apply one edit -----------------------------
  545. if c < lo: # need ↑ CREC
  546. if rng.random() < 0.5:
  547. rem_det_edge(triples, out_adj, in_adj, colours, depth, rng)
  548. else:
  549. add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng)
  550. elif rng.random() < 0.5:
  551. rem_div_edge(triples, out_adj, in_adj, colours, depth, rng)
  552. else:
  553. add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng)
  554. logging.warning("[worker %d] reached max_iters", worker_id)
  555. stop_event.set()
  556. return
  557. # --------------------------------------------------------------------
  558. # Public API --------------------------------------------------------
  559. # --------------------------------------------------------------------
  560. def fast_tune_crec(
  561. triples: List,
  562. colours: List,
  563. n_ent: int,
  564. n_rel: int,
  565. depth: int,
  566. c: float,
  567. lo: float,
  568. hi: float,
  569. max_iters: int = 1000,
  570. max_workers: int | None = None,
  571. seeds: List[int] | None = None,
  572. ) -> List[Tuple[int, int, int]]:
  573. """Tune WL‑CREC with *shared* triples using multiple threads.
  574. Returns the **same list instance** that was passed in – already
  575. modified in place by the winning thread.
  576. """
  577. if max_workers is None:
  578. # max_workers = os.cpu_count() or 4
  579. max_workers = 4
  580. if seeds is None:
  581. seeds = [42 + i for i in range(max_workers)]
  582. assert len(seeds) >= max_workers, "Need at least one seed per worker"
  583. # Prepare adjacency once (shared) --------------------------------
  584. out_adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
  585. in_adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
  586. for h, r, t in triples:
  587. out_adj[h].append((r, t))
  588. in_adj[t].append((r, h))
  589. state_lock = threading.Lock()
  590. stop_event = threading.Event()
  591. logging.info(
  592. "Launching %d threads on shared triples (target %.3f–%.3f)",
  593. max_workers,
  594. lo,
  595. hi,
  596. )
  597. threads: List[threading.Thread] = []
  598. for wid in range(max_workers):
  599. t = threading.Thread(
  600. name=f"tune‑crec‑worker‑{wid}",
  601. target=_worker,
  602. kwargs=dict(
  603. worker_id=wid,
  604. rng=random.Random(seeds[wid]),
  605. triples=triples,
  606. out_adj=out_adj,
  607. in_adj=in_adj,
  608. colours=colours,
  609. n_ent=n_ent,
  610. n_rel=n_rel,
  611. depth=depth,
  612. c=c,
  613. lo=lo,
  614. hi=hi,
  615. max_iters=max_iters,
  616. state_lock=state_lock,
  617. stop_event=stop_event,
  618. ),
  619. daemon=False,
  620. )
  621. threads.append(t)
  622. t.start()
  623. for t in threads:
  624. t.join()
  625. if not stop_event.is_set():
  626. raise RuntimeError("No thread converged – try increasing max_iters or widen band")
  627. return triples