You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

crec_radius_sample.py 3.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. from __future__ import annotations
  2. import logging
  3. import random
  4. from collections import defaultdict, deque
  5. from math import isclose
  6. from typing import List
  7. import torch
  8. from data.kg_dataset import KGDataset
  9. from metrics.wlcrec import WLCREC
  10. from tools import get_pretty_logger
# Project logger for this module, obtained from ``tools.get_pretty_logger``.
# NOTE(review): assigning to the name ``logging`` shadows the stdlib ``logging``
# module imported at the top of the file, so the ``logging.info(...)`` calls in
# this file go to this logger, not to the stdlib module. Renaming it to
# ``logger`` would be clearer, but every call site must be updated at once.
logging = get_pretty_logger(__name__)
  12. # --------------------------------------------------------------------
  13. # Radius-k patch sampler
  14. # --------------------------------------------------------------------
  15. def radius_patch_sample(
  16. all_triples: List,
  17. ratio: float,
  18. radius: int,
  19. rng: random.Random,
  20. ) -> List:
  21. """
  22. Take ~ratio fraction of triples by picking random seeds and
  23. keeping *all* triples within ≤ radius hops (undirected) of them.
  24. """
  25. n_total = len(all_triples)
  26. target = int(ratio * n_total)
  27. # Build entity → triple-indices adjacency for BFS
  28. ent2trip = defaultdict(list)
  29. for idx, (h, _, t) in enumerate(all_triples):
  30. ent2trip[h].append(idx)
  31. ent2trip[t].append(idx)
  32. chosen_idx: set[int] = set()
  33. visited_ent = set()
  34. while len(chosen_idx) < target:
  35. seed_idx = rng.randrange(n_total)
  36. if seed_idx in chosen_idx:
  37. continue
  38. # BFS over entities
  39. queue = deque()
  40. for e in all_triples[seed_idx][::2]: # subject & object
  41. if e not in visited_ent:
  42. queue.append((e, 0))
  43. visited_ent.add(e)
  44. while queue:
  45. ent, dist = queue.popleft()
  46. for tidx in ent2trip[ent]:
  47. if tidx not in chosen_idx:
  48. chosen_idx.add(tidx)
  49. if dist < radius:
  50. for tidx in ent2trip[ent]:
  51. h, _, t = all_triples[tidx]
  52. for nb in (h, t):
  53. if nb not in visited_ent:
  54. visited_ent.add(nb)
  55. queue.append((nb, dist + 1))
  56. # guard against infinite loop on very small ratio
  57. if isclose(ratio, 0.0) and chosen_idx:
  58. break
  59. return [all_triples[i] for i in chosen_idx]
  60. # --------------------------------------------------------------------
  61. # Main driver: repeated sampling until WL-CREC band hit
  62. # --------------------------------------------------------------------
  63. def find_sample(
  64. all_triples: List,
  65. n_ent: int,
  66. n_rel: int,
  67. depth: int,
  68. ratio: float,
  69. radius: int,
  70. lower: float,
  71. upper: float,
  72. max_tries: int,
  73. seed: int,
  74. ) -> List:
  75. for trial in range(1, max_tries + 1):
  76. rng = random.Random(seed + trial) # fresh randomness each try
  77. sample = radius_patch_sample(all_triples, ratio, radius, rng)
  78. sample_ds = KGDataset(
  79. triples_factory=None,
  80. triples=torch.tensor(sample, dtype=torch.long),
  81. num_entities=n_ent,
  82. num_relations=n_rel,
  83. split="train",
  84. )
  85. wl_crec = WLCREC(sample_ds)
  86. crec_val = wl_crec.compute(depth)[-2]
  87. logging.info("Try %d |T'|=%d WL-CREC=%.4f", trial, len(sample), crec_val)
  88. if lower <= crec_val <= upper:
  89. logging.info("Success after %d tries.", trial)
  90. return sample
  91. raise RuntimeError(f"No sample reached WL-CREC band in {max_tries} tries.")