from __future__ import annotations import logging import random from collections import defaultdict, deque from math import isclose from typing import List import torch from data.kg_dataset import KGDataset from metrics.wlcrec import WLCREC from tools import get_pretty_logger logging = get_pretty_logger(__name__) # -------------------------------------------------------------------- # Radius-k patch sampler # -------------------------------------------------------------------- def radius_patch_sample( all_triples: List, ratio: float, radius: int, rng: random.Random, ) -> List: """ Take ~ratio fraction of triples by picking random seeds and keeping *all* triples within ≤ radius hops (undirected) of them. """ n_total = len(all_triples) target = int(ratio * n_total) # Build entity → triple-indices adjacency for BFS ent2trip = defaultdict(list) for idx, (h, _, t) in enumerate(all_triples): ent2trip[h].append(idx) ent2trip[t].append(idx) chosen_idx: set[int] = set() visited_ent = set() while len(chosen_idx) < target: seed_idx = rng.randrange(n_total) if seed_idx in chosen_idx: continue # BFS over entities queue = deque() for e in all_triples[seed_idx][::2]: # subject & object if e not in visited_ent: queue.append((e, 0)) visited_ent.add(e) while queue: ent, dist = queue.popleft() for tidx in ent2trip[ent]: if tidx not in chosen_idx: chosen_idx.add(tidx) if dist < radius: for tidx in ent2trip[ent]: h, _, t = all_triples[tidx] for nb in (h, t): if nb not in visited_ent: visited_ent.add(nb) queue.append((nb, dist + 1)) # guard against infinite loop on very small ratio if isclose(ratio, 0.0) and chosen_idx: break return [all_triples[i] for i in chosen_idx] # -------------------------------------------------------------------- # Main driver: repeated sampling until WL-CREC band hit # -------------------------------------------------------------------- def find_sample( all_triples: List, n_ent: int, n_rel: int, depth: int, ratio: float, radius: int, lower: float, upper: float, max_tries: int, seed: int, ) -> List: for trial in range(1, max_tries + 1): rng = random.Random(seed + trial) # fresh randomness each try sample = radius_patch_sample(all_triples, ratio, radius, rng) sample_ds = KGDataset( triples_factory=None, triples=torch.tensor(sample, dtype=torch.long), num_entities=n_ent, num_relations=n_rel, split="train", ) wl_crec = WLCREC(sample_ds) crec_val = wl_crec.compute(depth)[-2] logging.info("Try %d |T'|=%d WL-CREC=%.4f", trial, len(sample), crec_val) if lower <= crec_val <= upper: logging.info("Success after %d tries.", trial) return sample raise RuntimeError(f"No sample reached WL-CREC band in {max_tries} tries.")