from __future__ import annotations import argparse import logging import torch from data import YAGO310Dataset # from metrics.crec_modifier import tune_crec_parallel_edits from data.kg_dataset import KGDataset from metrics.crec_radius_sample import find_sample from metrics.wlcrec import WLCREC from tools import get_pretty_logger logging = get_pretty_logger(__name__) # -------------------------------------------------------------------- # CLI # -------------------------------------------------------------------- def main(): p = argparse.ArgumentParser() p.add_argument("--depth", type=int, default=5) p.add_argument("--lower", type=float, default=0.25) p.add_argument("--upper", type=float, default=0.26) p.add_argument("--seed", type=int, default=0) p.add_argument("--save", type=str, default="fb15k_crec_bi.pt") args = p.parse_args() # ds = FB15KDataset() ds = YAGO310Dataset() triples = ds.triples n_ent, n_rel = ds.num_entities, ds.num_relations logging.info("YAGO3-10 loaded |V|=%d |R|=%d |T|=%d", n_ent, n_rel, len(triples)) # wlcrec = WLCREC(ds) # c = wlcrec.compute(5)[-2] # C_NWLEC # colours = wlcrec.wl_colours(args.depth)[-1] # tuned = tune_crec( # wlcrec, n_ent, n_rel, depth=args.depth, lo=args.lower, hi=args.upper, seed=args.seed # ) # tuned = tune_crec_parallel_edits( # triples, # n_ent, # n_rel, # depth=args.depth, # lo=args.lower, # hi=args.upper, # seed=args.seed, # ) # tuned = fast_tune_crec( # triples.tolist(), # colours, # n_ent, # n_rel, # depth=args.depth, # c=c, # lo=args.lower, # hi=args.upper, # max_iters=1000, # max_workers=10, # Adjust number of workers as needed # ) tuned = find_sample( triples.tolist(), n_ent, n_rel, depth=args.depth, ratio=0.1, # Adjust ratio as needed radius=2, # Adjust radius as needed lower=args.lower, upper=args.upper, max_tries=1000, seed=58, ) tuned_ds = KGDataset( triples_factory=None, triples=torch.tensor(tuned, dtype=torch.long), num_entities=n_ent, num_relations=n_rel, split="train", ) tuned_wlcrec = WLCREC(tuned_ds) entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = tuned_wlcrec.compute(5) print( f"\nTuned CREC results (H={5})", f"\n • avg WLEC : {entropy:.6f} nats", f"\n • C_ratio : {c_ratio:.6f}", f"\n • C_NWLEC : {c_nwlec:.6f}", f"\n • H_cond(R|S_H) : {h_cond:.6f} nats", f"\n • D_ratio : {d_ratio:.6f}", f"\n • D_NWLEC : {d_nwlec:.6f}", ) torch.save(torch.tensor(tuned, dtype=torch.long), args.save) logging.info("Saved tuned triples → %s", args.save) if __name__ == "__main__": main()