123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- from __future__ import annotations
-
- import argparse
- import logging
-
- import torch
-
- from data import YAGO310Dataset
-
- # from metrics.crec_modifier import tune_crec_parallel_edits
- from data.kg_dataset import KGDataset
- from metrics.crec_radius_sample import find_sample
- from metrics.wlcrec import WLCREC
- from tools import get_pretty_logger
-
- logging = get_pretty_logger(__name__)
-
-
- # --------------------------------------------------------------------
- # CLI
- # --------------------------------------------------------------------
def main() -> None:
    """Run the sampling pipeline from the command line.

    Steps, in order:

    1. Parse the CLI flags listed below.
    2. Load ``YAGO310Dataset`` and log its entity/relation/triple counts.
    3. Call ``find_sample`` on the dataset's triples.
    4. Wrap the returned triples in a ``KGDataset``, compute its ``WLCREC``
       metrics and print them.
    5. Save the returned triples as a ``torch.long`` tensor to ``--save``.

    CLI flags:
        --depth: forwarded to ``find_sample`` as ``depth`` (default 5).
        --lower / --upper: forwarded to ``find_sample`` as ``lower`` /
            ``upper`` (defaults 0.25 / 0.26).
        --seed: forwarded to ``find_sample`` as ``seed`` (default 58).
        --save: path the sampled triples are written to.

    Note: ``logging`` in this module is the object returned by
    ``get_pretty_logger(__name__)``, not the standard-library module.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--depth", type=int, default=5)
    parser.add_argument("--lower", type=float, default=0.25)
    parser.add_argument("--upper", type=float, default=0.26)
    # The seed used to be hard-coded to 58 in the find_sample call below, so
    # --seed was parsed but silently ignored. The default is now 58 so that a
    # run without flags is unchanged, while an explicit --seed takes effect.
    parser.add_argument("--seed", type=int, default=58)
    # NOTE(review): the default file name mentions fb15k although YAGO3-10 is
    # loaded below; kept unchanged so existing runs keep writing to the same
    # path.
    parser.add_argument("--save", type=str, default="fb15k_crec_bi.pt")
    args = parser.parse_args()

    ds = YAGO310Dataset()
    triples = ds.triples
    n_ent, n_rel = ds.num_entities, ds.num_relations
    logging.info("YAGO3-10 loaded |V|=%d |R|=%d |T|=%d", n_ent, n_rel, len(triples))

    # Disabled alternatives that used to sit here as commented-out calls:
    # tune_crec, tune_crec_parallel_edits and fast_tune_crec. find_sample is
    # the only strategy this script invokes.
    tuned = find_sample(
        triples.tolist(),
        n_ent,
        n_rel,
        depth=args.depth,
        ratio=0.1,  # Adjust ratio as needed
        radius=2,  # Adjust radius as needed
        lower=args.lower,
        upper=args.upper,
        max_tries=1000,
        seed=args.seed,
    )

    tuned_ds = KGDataset(
        triples_factory=None,
        triples=torch.tensor(tuned, dtype=torch.long),
        num_entities=n_ent,
        num_relations=n_rel,
        split="train",
    )

    # NOTE(review): the metrics are computed at a fixed depth of 5, which does
    # not follow --depth — confirm this is intended.
    eval_depth = 5
    tuned_wlcrec = WLCREC(tuned_ds)
    entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = tuned_wlcrec.compute(
        eval_depth
    )
    # Kept as separate print arguments so the emitted text (including the
    # separator print inserts between arguments) matches the previous output.
    print(
        f"\nTuned CREC results (H={eval_depth})",
        f"\n • avg WLEC : {entropy:.6f} nats",
        f"\n • C_ratio : {c_ratio:.6f}",
        f"\n • C_NWLEC : {c_nwlec:.6f}",
        f"\n • H_cond(R|S_H) : {h_cond:.6f} nats",
        f"\n • D_ratio : {d_ratio:.6f}",
        f"\n • D_NWLEC : {d_nwlec:.6f}",
    )

    torch.save(torch.tensor(tuned, dtype=torch.long), args.save)
    logging.info("Saved tuned triples → %s", args.save)
-
-
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|