123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- from __future__ import annotations
-
- import logging
- import random
- from collections import defaultdict, deque
- from math import isclose
- from typing import List
-
- import torch
-
- from data.kg_dataset import KGDataset
- from metrics.wlcrec import WLCREC
- from tools import get_pretty_logger
-
# Module-level logger used by the functions below.
# NOTE: this rebinds the name `logging`; from this line on it refers to the
# object returned by get_pretty_logger, not to the stdlib `logging` module
# imported above.
logging = get_pretty_logger(__name__)
-
-
-
- # --------------------------------------------------------------------
- # Radius-k patch sampler
- # --------------------------------------------------------------------
def radius_patch_sample(
    all_triples: List,
    ratio: float,
    radius: int,
    rng: random.Random,
) -> List:
    """
    Sample roughly ``ratio`` of the triples as radius-``radius`` patches.

    Seed triples are drawn uniformly at random with ``rng``. For every seed a
    breadth-first search over entities (edges treated as undirected) starts
    from the seed's head and tail, and every triple incident to an entity
    reached within ``radius`` hops is kept. Seeding repeats until at least
    ``int(ratio * len(all_triples))`` triples are kept; because whole patches
    are added, the result may overshoot that target.

    Args:
        all_triples: Sequence of ``(head, relation, tail)`` triples; heads and
            tails must be hashable.
        ratio: Fraction of triples to aim for. Values above 1 are treated as 1
            (every triple is returned); a target of zero or less yields an
            empty list.
        radius: Largest hop distance from a seed endpoint at which entities
            are still collected; 0 keeps only triples touching the seed's own
            head and tail.
        rng: Random generator used to pick seed triples.

    Returns:
        The selected triples, in the iteration order of the internal index set.

    Note:
        Visited entities are shared across seeds. An entity first reached at
        the radius limit from an earlier seed is therefore not expanded again
        when a later seed reaches it at a smaller distance, so later patches
        can be smaller than a full radius-``radius`` neighbourhood.
    """
    n_total = len(all_triples)
    # Clamp to n_total: the loop can never collect more triples than exist,
    # so an unclamped target (ratio > 1) would make it spin forever.
    target = min(int(ratio * n_total), n_total)

    # Build entity -> triple-indices adjacency for the BFS.
    ent2trip = defaultdict(list)
    for idx, (h, _, t) in enumerate(all_triples):
        ent2trip[h].append(idx)
        ent2trip[t].append(idx)

    chosen_idx: set[int] = set()
    visited_ent = set()

    while len(chosen_idx) < target:
        seed_idx = rng.randrange(n_total)
        if seed_idx in chosen_idx:
            continue

        # Start the BFS from the seed triple's head and tail (positions 0, 2).
        queue = deque()
        for e in all_triples[seed_idx][::2]:
            if e not in visited_ent:
                visited_ent.add(e)
                queue.append((e, 0))

        while queue:
            ent, dist = queue.popleft()
            incident = ent2trip[ent]
            # Keep every triple touching this entity.
            chosen_idx.update(incident)
            # Only entities strictly inside the radius propagate further.
            if dist < radius:
                for tidx in incident:
                    h, _, t = all_triples[tidx]
                    for nb in (h, t):
                        if nb not in visited_ent:
                            visited_ent.add(nb)
                            queue.append((nb, dist + 1))

    return [all_triples[i] for i in chosen_idx]
-
-
- # --------------------------------------------------------------------
- # Main driver: repeated sampling until WL-CREC band hit
- # --------------------------------------------------------------------
def find_sample(
    all_triples: List,
    n_ent: int,
    n_rel: int,
    depth: int,
    ratio: float,
    radius: int,
    lower: float,
    upper: float,
    max_tries: int,
    seed: int,
) -> List:
    """
    Draw radius-patch samples until one lands inside the requested band.

    Up to ``max_tries`` attempts are made. Attempt ``k`` (1-based) samples
    with ``random.Random(seed + k)``, wraps the sample in a ``KGDataset`` and
    scores it with ``WLCREC(...).compute(depth)[-2]``. The first sample whose
    score satisfies ``lower <= score <= upper`` is returned.

    Args:
        all_triples: Full triple list to sample from.
        n_ent: Passed to ``KGDataset`` as ``num_entities``.
        n_rel: Passed to ``KGDataset`` as ``num_relations``.
        depth: Argument forwarded to ``WLCREC.compute``.
        ratio: Target fraction forwarded to ``radius_patch_sample``.
        radius: Patch radius forwarded to ``radius_patch_sample``.
        lower: Inclusive lower bound of the accepted score band.
        upper: Inclusive upper bound of the accepted score band.
        max_tries: Maximum number of sampling attempts.
        seed: Base seed; each attempt offsets it by its 1-based index.

    Returns:
        The first sampled triple list whose score lies within the band.

    Raises:
        RuntimeError: If no attempt produced a score inside the band.
    """
    for attempt in range(1, max_tries + 1):
        # A distinct but reproducible generator for every attempt.
        attempt_rng = random.Random(seed + attempt)
        candidate = radius_patch_sample(all_triples, ratio, radius, attempt_rng)

        candidate_tensor = torch.tensor(candidate, dtype=torch.long)
        dataset = KGDataset(
            triples_factory=None,
            triples=candidate_tensor,
            num_entities=n_ent,
            num_relations=n_rel,
            split="train",
        )

        # NOTE(review): index -2 of compute()'s result is taken as the WL-CREC
        # score (per the log label below) -- confirm against WLCREC.compute.
        score = WLCREC(dataset).compute(depth)[-2]
        logging.info("Try %d |T'|=%d WL-CREC=%.4f", attempt, len(candidate), score)

        if not (lower <= score <= upper):
            continue

        logging.info("Success after %d tries.", attempt)
        return candidate

    raise RuntimeError(f"No sample reached WL-CREC band in {max_tries} tries.")
|