You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

build_crec_datasets.py 2.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. from __future__ import annotations
  2. import argparse
  3. import logging
  4. import torch
  5. from data import YAGO310Dataset
  6. # from metrics.crec_modifier import tune_crec_parallel_edits
  7. from data.kg_dataset import KGDataset
  8. from metrics.crec_radius_sample import find_sample
  9. from metrics.wlcrec import WLCREC
  10. from tools import get_pretty_logger
  11. logging = get_pretty_logger(__name__)
  12. # --------------------------------------------------------------------
  13. # CLI
  14. # --------------------------------------------------------------------
  15. def main():
  16. p = argparse.ArgumentParser()
  17. p.add_argument("--depth", type=int, default=5)
  18. p.add_argument("--lower", type=float, default=0.25)
  19. p.add_argument("--upper", type=float, default=0.26)
  20. p.add_argument("--seed", type=int, default=0)
  21. p.add_argument("--save", type=str, default="fb15k_crec_bi.pt")
  22. args = p.parse_args()
  23. # ds = FB15KDataset()
  24. ds = YAGO310Dataset()
  25. triples = ds.triples
  26. n_ent, n_rel = ds.num_entities, ds.num_relations
  27. logging.info("YAGO3-10 loaded |V|=%d |R|=%d |T|=%d", n_ent, n_rel, len(triples))
  28. # wlcrec = WLCREC(ds)
  29. # c = wlcrec.compute(5)[-2] # C_NWLEC
  30. # colours = wlcrec.wl_colours(args.depth)[-1]
  31. # tuned = tune_crec(
  32. # wlcrec, n_ent, n_rel, depth=args.depth, lo=args.lower, hi=args.upper, seed=args.seed
  33. # )
  34. # tuned = tune_crec_parallel_edits(
  35. # triples,
  36. # n_ent,
  37. # n_rel,
  38. # depth=args.depth,
  39. # lo=args.lower,
  40. # hi=args.upper,
  41. # seed=args.seed,
  42. # )
  43. # tuned = fast_tune_crec(
  44. # triples.tolist(),
  45. # colours,
  46. # n_ent,
  47. # n_rel,
  48. # depth=args.depth,
  49. # c=c,
  50. # lo=args.lower,
  51. # hi=args.upper,
  52. # max_iters=1000,
  53. # max_workers=10, # Adjust number of workers as needed
  54. # )
  55. tuned = find_sample(
  56. triples.tolist(),
  57. n_ent,
  58. n_rel,
  59. depth=args.depth,
  60. ratio=0.1, # Adjust ratio as needed
  61. radius=2, # Adjust radius as needed
  62. lower=args.lower,
  63. upper=args.upper,
  64. max_tries=1000,
  65. seed=58,
  66. )
  67. tuned_ds = KGDataset(
  68. triples_factory=None,
  69. triples=torch.tensor(tuned, dtype=torch.long),
  70. num_entities=n_ent,
  71. num_relations=n_rel,
  72. split="train",
  73. )
  74. tuned_wlcrec = WLCREC(tuned_ds)
  75. entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = tuned_wlcrec.compute(5)
  76. print(
  77. f"\nTuned CREC results (H={5})",
  78. f"\n • avg WLEC : {entropy:.6f} nats",
  79. f"\n • C_ratio : {c_ratio:.6f}",
  80. f"\n • C_NWLEC : {c_nwlec:.6f}",
  81. f"\n • H_cond(R|S_H) : {h_cond:.6f} nats",
  82. f"\n • D_ratio : {d_ratio:.6f}",
  83. f"\n • D_NWLEC : {d_nwlec:.6f}",
  84. )
  85. torch.save(torch.tensor(tuned, dtype=torch.long), args.save)
  86. logging.info("Saved tuned triples → %s", args.save)
  87. if __name__ == "__main__":
  88. main()