| @@ -0,0 +1 @@ | |||
| # KGEvaluation | |||
| @@ -0,0 +1,106 @@ | |||
| from __future__ import annotations | |||
| import argparse | |||
| import torch | |||
| from data import YAGO310Dataset | |||
| # from metrics.crec_modifier import tune_crec_parallel_edits | |||
| from data.kg_dataset import KGDataset | |||
| from metrics.crec_radius_sample import find_sample | |||
| from metrics.wlcrec import WLCREC | |||
| from tools import get_pretty_logger | |||
| logger = get_pretty_logger(__name__) | |||
| # -------------------------------------------------------------------- | |||
| # CLI | |||
| # -------------------------------------------------------------------- | |||
| def main(): | |||
| p = argparse.ArgumentParser() | |||
| p.add_argument("--depth", type=int, default=5) | |||
| p.add_argument("--lower", type=float, default=0.25) | |||
| p.add_argument("--upper", type=float, default=0.26) | |||
| p.add_argument("--seed", type=int, default=0) | |||
| p.add_argument("--save", type=str, default="fb15k_crec_bi.pt") | |||
| args = p.parse_args() | |||
| # ds = FB15KDataset() | |||
| ds = YAGO310Dataset() | |||
| triples = ds.triples | |||
| n_ent, n_rel = ds.num_entities, ds.num_relations | |||
| logging.info("YAGO3-10 loaded |V|=%d |R|=%d |T|=%d", n_ent, n_rel, len(triples)) | |||
| # wlcrec = WLCREC(ds) | |||
| # c = wlcrec.compute(5)[-2] # C_NWLEC | |||
| # colours = wlcrec.wl_colours(args.depth)[-1] | |||
| # tuned = tune_crec( | |||
| # wlcrec, n_ent, n_rel, depth=args.depth, lo=args.lower, hi=args.upper, seed=args.seed | |||
| # ) | |||
| # tuned = tune_crec_parallel_edits( | |||
| # triples, | |||
| # n_ent, | |||
| # n_rel, | |||
| # depth=args.depth, | |||
| # lo=args.lower, | |||
| # hi=args.upper, | |||
| # seed=args.seed, | |||
| # ) | |||
| # tuned = fast_tune_crec( | |||
| # triples.tolist(), | |||
| # colours, | |||
| # n_ent, | |||
| # n_rel, | |||
| # depth=args.depth, | |||
| # c=c, | |||
| # lo=args.lower, | |||
| # hi=args.upper, | |||
| # max_iters=1000, | |||
| # max_workers=10, # Adjust number of workers as needed | |||
| # ) | |||
| tuned = find_sample( | |||
| triples.tolist(), | |||
| n_ent, | |||
| n_rel, | |||
| depth=args.depth, | |||
| ratio=0.1, # Adjust ratio as needed | |||
| radius=2, # Adjust radius as needed | |||
| lower=args.lower, | |||
| upper=args.upper, | |||
| max_tries=1000, | |||
| seed=58, | |||
| ) | |||
| tuned_ds = KGDataset( | |||
| triples_factory=None, | |||
| triples=torch.tensor(tuned, dtype=torch.long), | |||
| num_entities=n_ent, | |||
| num_relations=n_rel, | |||
| split="train", | |||
| ) | |||
| tuned_wlcrec = WLCREC(tuned_ds) | |||
| entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = tuned_wlcrec.compute(5) | |||
| print( | |||
| f"\nTuned CREC results (H={5})", | |||
| f"\n • avg WLEC : {entropy:.6f} nats", | |||
| f"\n • C_ratio : {c_ratio:.6f}", | |||
| f"\n • C_NWLEC : {c_nwlec:.6f}", | |||
| f"\n • H_cond(R|S_H) : {h_cond:.6f} nats", | |||
| f"\n • D_ratio : {d_ratio:.6f}", | |||
| f"\n • D_NWLEC : {d_nwlec:.6f}", | |||
| ) | |||
| torch.save(torch.tensor(tuned, dtype=torch.long), args.save) | |||
| logging.info("Saved tuned triples → %s", args.save) | |||
| if __name__ == "__main__": | |||
| main() | |||
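| # --- Example (sketch, not executed) ----------------------------------------- | |||
| # The tensor saved above can later be reloaded and wrapped in a KGDataset for | |||
| # further analysis; num_entities / num_relations must match the source graph. | |||
| # tuned = torch.load("yago310_crec_bi.pt", map_location="cpu").to(torch.long) | |||
| # ds = KGDataset(triples_factory=None, triples=tuned, num_entities=n_ent, | |||
| #                num_relations=n_rel, split="train") | |||
| # print(WLCREC(ds).compute(5)) | |||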
| @@ -0,0 +1,20 @@ | |||
| _target_: tools.params.CommonParams | |||
| model_name: "default_model" | |||
| project_name: "my_project" | |||
| run_name: ... | |||
| # Where to write all outputs (checkpoints / tensorboard logs / etc.) | |||
| save_dpath: "./checkpoints" | |||
| save_every: 200 | |||
| # If you want to resume from a checkpoint, set load_path to a string path; | |||
| # otherwise leave it null or empty. | |||
| load_path: null | |||
| # Default log directory | |||
| log_dir: "logs" | |||
| log_every: 10 | |||
| log_console_every: 10 | |||
| evaluate_only: false | |||
| @@ -0,0 +1,9 @@ | |||
| defaults: | |||
| - common: common | |||
| - model: model | |||
| - data: data | |||
| - training: training | |||
| hydra: | |||
| run: | |||
| dir: "./outputs" | |||
| @@ -0,0 +1,3 @@ | |||
| _target_: data.wn18.WN18RRDataset | |||
| split: train | |||
| @@ -0,0 +1,8 @@ | |||
| train: | |||
| _target_: data.fb15k.FB15KDataset | |||
| split: train | |||
| valid: | |||
| _target_: data.fb15k.FB15KDataset | |||
| split: valid | |||
| @@ -0,0 +1,8 @@ | |||
| train: | |||
| _target_: data.wn18.WN18Dataset | |||
| split: train | |||
| valid: | |||
| _target_: data.wn18.WN18Dataset | |||
| split: valid | |||
| @@ -0,0 +1,8 @@ | |||
| train: | |||
| _target_: data.wn18.WN18RRDataset | |||
| split: train | |||
| valid: | |||
| _target_: data.wn18.WN18RRDataset | |||
| split: valid | |||
| @@ -0,0 +1,9 @@ | |||
| train: | |||
| _target_: data.yago3_10.YAGO310Dataset | |||
| split: train | |||
| valid: | |||
| _target_: data.yago3_10.YAGO310Dataset | |||
| split: valid | |||
| @@ -0,0 +1,3 @@ | |||
| num_entities: 100 | |||
| num_relations: 100 | |||
| dim: 100 | |||
| @@ -0,0 +1,9 @@ | |||
| defaults: | |||
| - model | |||
| - _self_ | |||
| _target_: models.translation.trans_e.TransE | |||
| dim: 200 | |||
| sf_norm: 1 | |||
| p_norm: true | |||
| @@ -0,0 +1,11 @@ | |||
| defaults: | |||
| - model | |||
| - _self_ | |||
| _target_: models.translation.trans_h.TransH | |||
| dim: 200 | |||
| sf_norm: 1 | |||
| p_norm: true | |||
| p_norm_value: 2 | |||
| margin: 1.0 | |||
| @@ -0,0 +1,12 @@ | |||
| defaults: | |||
| - model | |||
| - _self_ | |||
| _target_: models.translation.trans_r.TransR | |||
| entity_dim: 200 | |||
| relation_dim: 30 | |||
| sf_norm: 1 | |||
| p_norm: true | |||
| p_norm_value: 2 | |||
| margin: 1.0 | |||
| @@ -0,0 +1,21 @@ | |||
| _target_: tools.params.TrainingParams | |||
| lr: 0.005 | |||
| min_lr: 0.0001 | |||
| weight_decay: 0.0 | |||
| t0: 50 | |||
| lr_step: 10 | |||
| gamma: 0.95 | |||
| batch_size: 8192 | |||
| num_workers: 3 | |||
| seed: 42 | |||
| num_train_steps: 10000 | |||
| eval_every: 100 | |||
| validation_sample_size: 1000 | |||
| validation_batch_size: 64 | |||
| model_trainer: | |||
| _target_: training.model_trainers.model_trainer_base.ModelTrainerBase | |||
| @@ -0,0 +1,11 @@ | |||
| defaults: | |||
| - training | |||
| - _self_ | |||
| model_trainer: | |||
| _target_: training.model_trainers.translation.trans_e_trainer.TransETrainer | |||
| margin: 1.0 | |||
| n_negative: 5 | |||
| regul_rate: 0.0 | |||
| loss_fn: "margin" | |||
| @@ -0,0 +1,11 @@ | |||
| defaults: | |||
| - training | |||
| - _self_ | |||
| model_trainer: | |||
| _target_: training.model_trainers.translation.trans_r_trainer.TransRTrainer | |||
| margin: 1.0 | |||
| n_negative: 10 | |||
| regul_rate: 0.0 | |||
| loss_fn: "margin" | |||
| @@ -0,0 +1,17 @@ | |||
| defaults: | |||
| - config | |||
| - override common: common | |||
| - override data: fb15k | |||
| - override model: trans_e | |||
| - override training: trans_e_trainer | |||
| - _self_ | |||
| training: | |||
| model_trainer: | |||
| loss_fn: "margin" | |||
| common: | |||
| run_name: "TransE_FB15K" | |||
| evaluate_only: true | |||
| load_path: (2, 1800) | |||
| @@ -0,0 +1,17 @@ | |||
| defaults: | |||
| - config | |||
| - override common: common | |||
| - override data: wn18 | |||
| - override model: trans_e | |||
| - override training: trans_e_trainer | |||
| - _self_ | |||
| training: | |||
| model_trainer: | |||
| loss_fn: "margin" | |||
| common: | |||
| run_name: "TransE_WN18" | |||
| evaluate_only: true | |||
| load_path: (82, 3000) | |||
| @@ -0,0 +1,18 @@ | |||
| defaults: | |||
| - config | |||
| - override common: common | |||
| - override data: wn18rr | |||
| - override model: trans_e | |||
| - override training: trans_e_trainer | |||
| - _self_ | |||
| training: | |||
| model_trainer: | |||
| loss_fn: "margin" | |||
| common: | |||
| run_name: "TransE_WN18rr" | |||
| evaluate_only: true | |||
| load_path: (2, 6000) | |||
| @@ -0,0 +1,17 @@ | |||
| defaults: | |||
| - config | |||
| - override common: common | |||
| - override data: yago3_10 | |||
| - override model: trans_e | |||
| - override training: trans_e_trainer | |||
| - _self_ | |||
| training: | |||
| model_trainer: | |||
| loss_fn: "margin" | |||
| common: | |||
| run_name: "TransE_YAGO310" | |||
| evaluate_only: true | |||
| load_path: (1, 2600) | |||
| @@ -0,0 +1,17 @@ | |||
| defaults: | |||
| - config | |||
| - override common: common | |||
| - override data: wn18 | |||
| - override model: trans_r | |||
| - override training: trans_r_trainer | |||
| - _self_ | |||
| training: | |||
| model_trainer: | |||
| loss_fn: "margin" | |||
| common: | |||
| run_name: "TransR_WN18" | |||
| # evaluate_only: true | |||
| # load_path: (1, 1000) | |||
| @@ -0,0 +1,6 @@ | |||
| from .fb15k import FB15KDataset | |||
| from .wn18 import WN18Dataset, WN18RRDataset | |||
| from .yago3_10 import YAGO310Dataset | |||
| from .hationet import HetionetDataset | |||
| from .open_bio_link import OpenBioLinkDataset | |||
| from .openke_wiki import OpenKEWikiDataset | |||
| @@ -0,0 +1,49 @@ | |||
| from typing import List, Tuple | |||
| import torch | |||
| from pykeen.datasets import FB15k | |||
| from .kg_dataset import KGDataset | |||
| class FB15KDataset(KGDataset): | |||
| """A wrapper around PyKeen's FB15k dataset producing KGDataset instances.""" | |||
| def __init__(self, split: str = "train") -> None: | |||
| dataset = FB15k() | |||
| # if split == "train": | |||
| if True: | |||
| tf = dataset.training | |||
| elif split in ("valid", "val", "validation"): | |||
| tf = dataset.validation | |||
| elif split in ("test", "testing"): | |||
| tf = dataset.testing | |||
| else: | |||
| msg = f"Unrecognized split '{split}'" | |||
| raise ValueError(msg) | |||
| custom_triples = torch.load( | |||
| "/home/n_kazemi/projects/KGEvaluation/CustomDataset/fb15k/fb15k_crec_bi_1.pt", | |||
| map_location=torch.device("cpu"), | |||
| ).to(torch.long) | |||
| # .tolist() | |||
| tf = tf.clone_and_exchange_triples( | |||
| custom_triples, | |||
| ) | |||
| triples_list: List[Tuple[int, int, int]] = [ | |||
| (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||
| ] | |||
| super().__init__( | |||
| triples_factory=tf, | |||
| triples=triples_list, | |||
| num_entities=dataset.num_entities, | |||
| num_relations=dataset.num_relations, | |||
| split=split, | |||
| ) | |||
| def __getitem__(self, idx: int) -> torch.Tensor: | |||
| return torch.tensor(super().__getitem__(idx), dtype=torch.long) | |||
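| # Usage sketch (assumes the PyKeen FB15k files and the custom triples file above | |||
| # are available locally): | |||
| # from torch.utils.data import DataLoader | |||
| # ds = FB15KDataset(split="train") | |||
| # loader = DataLoader(ds, batch_size=1024, shuffle=True) | |||
| # batch = next(iter(loader))  # LongTensor of shape (1024, 3) with (h, r, t) rows | |||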
| @@ -0,0 +1,57 @@ | |||
| from typing import List, Tuple | |||
| import torch | |||
| from pykeen.datasets import Hetionet | |||
| from .kg_dataset import KGDataset # Importing KGDataset from the same module | |||
| class HetionetDataset(KGDataset): | |||
| """A wrapper around PyKeen's Hetionet dataset that produces KGDataset instances.""" | |||
| def __init__(self, split: str = "train") -> None: | |||
| """ | |||
| Initialize the Hetionet wrapper. | |||
| :param split: One of "train", "valid", or "test" indicating which split to load. | |||
| """ | |||
| # Load the PyKeen Hetionet dataset | |||
| dataset = Hetionet() | |||
| # Select the appropriate TriplesFactory based on the requested split | |||
| if split == "train": | |||
| tf = dataset.training | |||
| elif split in ("valid", "val", "validation"): | |||
| tf = dataset.validation | |||
| elif split in ("test", "testing"): | |||
| tf = dataset.testing | |||
| else: | |||
| msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||
| raise ValueError( | |||
| msg, | |||
| ) | |||
| # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3) | |||
| # Convert each row into a Python tuple of ints | |||
| triples_list: List[Tuple[int, int, int]] = [ | |||
| (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||
| ] | |||
| # Get the total number of entities and relations in the entire dataset | |||
| num_entities = dataset.num_entities | |||
| num_relations = dataset.num_relations | |||
| # Initialize the base KGDataset with the selected triples | |||
| super().__init__( | |||
| triples_factory=tf, | |||
| triples=triples_list, | |||
| num_entities=num_entities, | |||
| num_relations=num_relations, | |||
| split=split, | |||
| ) | |||
| def __getitem__(self, idx: int) -> torch.Tensor: | |||
| data = super().__getitem__(idx) | |||
| # Convert the tuple to a torch tensor | |||
| return torch.tensor(data, dtype=torch.long) | |||
| @@ -0,0 +1,89 @@ | |||
| """PyTorch Dataset wrapping indexed triples (h, r, t).""" | |||
| from pathlib import Path | |||
| from typing import Dict, List, Optional, Tuple | |||
| import torch | |||
| from pykeen.triples import TriplesFactory | |||
| from torch.utils.data import Dataset | |||
| class KGDataset(Dataset): | |||
| """PyTorch Dataset wrapping indexed triples (h, r, t).""" | |||
| def __init__( | |||
| self, | |||
| triples_factory: Optional[TriplesFactory], | |||
| triples: List[Tuple[int, int, int]], | |||
| num_entities: int, | |||
| num_relations: int, | |||
| split: str = "train", | |||
| ) -> None: | |||
| super().__init__() | |||
| self.triples_factory = triples_factory | |||
| self.triples = torch.tensor(triples, dtype=torch.long) | |||
| self.num_entities = num_entities | |||
| self.num_relations = num_relations | |||
| self.split = split | |||
| def __len__(self) -> int: | |||
| """Return the number of triples in the dataset.""" | |||
| return len(self.triples) | |||
| def __getitem__(self, idx: int) -> torch.Tensor: | |||
| """Return the (h, r, t) triple at the given index as a LongTensor.""" | |||
| return self.triples[idx] | |||
| def entities(self) -> torch.Tensor: | |||
| """Return a list of all entities in the dataset.""" | |||
| return torch.arange(self.num_entities) | |||
| def relations(self) -> torch.Tensor: | |||
| """Return a list of all relations in the dataset.""" | |||
| return torch.arange(self.num_relations) | |||
| def __repr__(self) -> str: | |||
| return ( | |||
| f"{self.__class__.__name__}(split={self.split!r}, " | |||
| f"num_triples={len(self.triples)}, " | |||
| f"num_entities={self.num_entities}, " | |||
| f"num_relations={self.num_relations})" | |||
| ) | |||
| # ---------------------------- Helper functions -------------------------------- | |||
| def _map_to_ids( | |||
| triples: List[Tuple[str, str, str]], | |||
| ) -> Tuple[List[Tuple[int, int, int]], Dict[str, int], Dict[str, int]]: | |||
| ent2id: Dict[str, int] = {} | |||
| rel2id: Dict[str, int] = {} | |||
| mapped: List[Tuple[int, int, int]] = [] | |||
| def _id(d: Dict[str, int], k: str) -> int: | |||
| d.setdefault(k, len(d)) | |||
| return d[k] | |||
| for h, r, t in triples: | |||
| mapped.append((_id(ent2id, h), _id(rel2id, r), _id(ent2id, t))) | |||
| return mapped, ent2id, rel2id | |||
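| # Worked example: | |||
| #   _map_to_ids([("a", "likes", "b"), ("b", "likes", "a")]) | |||
| #   -> ([(0, 0, 1), (1, 0, 0)], {"a": 0, "b": 1}, {"likes": 0}) | |||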
| def load_kg_from_files( | |||
| root: str, splits: Tuple[str, ...] = ("train", "valid", "test") | |||
| ) -> Tuple[Dict[str, KGDataset], Dict[str, int], Dict[str, int]]: | |||
| root = Path(root) | |||
| split_raw, all_raw = {}, [] | |||
| for s in splits: | |||
| lines = [tuple(line.split("\t")) for line in (root / f"{s}.txt").read_text().splitlines()] | |||
| split_raw[s] = lines | |||
| all_raw.extend(lines) | |||
| mapped, ent2id, rel2id = _map_to_ids(all_raw) | |||
| datasets, start = {}, 0 | |||
| for s in splits: | |||
| end = start + len(split_raw[s]) | |||
| datasets[s] = KGDataset( | |||
| triples_factory=None, | |||
| triples=mapped[start:end], | |||
| num_entities=len(ent2id), | |||
| num_relations=len(rel2id), | |||
| split=s, | |||
| ) | |||
| start = end | |||
| return datasets, ent2id, rel2id | |||
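| # Usage sketch (hypothetical directory): `root` must contain train.txt, valid.txt | |||
| # and test.txt with tab-separated "<head>\t<relation>\t<tail>" lines. | |||
| # datasets, ent2id, rel2id = load_kg_from_files("data/raw/my_kg") | |||
| # print(datasets["train"], len(ent2id), len(rel2id)) | |||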
| @@ -0,0 +1,57 @@ | |||
| from typing import List, Tuple | |||
| import torch | |||
| from pykeen.datasets import OpenBioLink | |||
| from .kg_dataset import KGDataset # Importing KGDataset from the same module | |||
| class OpenBioLinkDataset(KGDataset): | |||
| """A wrapper around PyKeen's OpenBioLink dataset that produces KGDataset instances.""" | |||
| def __init__(self, split: str = "train") -> None: | |||
| """ | |||
| Initialize the OpenBioLink wrapper. | |||
| :param split: One of "train", "valid", or "test" indicating which split to load. | |||
| """ | |||
| # Load the PyKeen OpenBioLink dataset | |||
| dataset = OpenBioLink() | |||
| # Select the appropriate TriplesFactory based on the requested split | |||
| if split == "train": | |||
| tf = dataset.training | |||
| elif split in ("valid", "val", "validation"): | |||
| tf = dataset.validation | |||
| elif split in ("test", "testing"): | |||
| tf = dataset.testing | |||
| else: | |||
| msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||
| raise ValueError( | |||
| msg, | |||
| ) | |||
| # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3) | |||
| # Convert each row into a Python tuple of ints | |||
| triples_list: List[Tuple[int, int, int]] = [ | |||
| (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||
| ] | |||
| # Get the total number of entities and relations in the entire dataset | |||
| num_entities = dataset.num_entities | |||
| num_relations = dataset.num_relations | |||
| # Initialize the base KGDataset with the selected triples | |||
| super().__init__( | |||
| triples_factory=tf, | |||
| triples=triples_list, | |||
| num_entities=num_entities, | |||
| num_relations=num_relations, | |||
| split=split, | |||
| ) | |||
| def __getitem__(self, idx: int) -> torch.Tensor: | |||
| data = super().__getitem__(idx) | |||
| # Convert the tuple to a torch tensor | |||
| return torch.tensor(data, dtype=torch.long) | |||
| @@ -0,0 +1,57 @@ | |||
| from typing import List, Tuple | |||
| import torch | |||
| from pykeen.datasets import Wikidata5M | |||
| from .kg_dataset import KGDataset # Importing KGDataset from the same module | |||
| class OpenKEWikiDataset(KGDataset): | |||
| """A wrapper around PyKeen's OpenKE-Wiki (WikiN3) dataset that produces KGDataset instances.""" | |||
| def __init__(self, split: str = "train") -> None: | |||
| """ | |||
| Initialize the OpenKE-Wiki (Wikidata5M) wrapper. | |||
| :param split: One of "train", "valid", or "test" indicating which split to load. | |||
| """ | |||
| # Load the PyKeen Wikidata5M dataset | |||
| dataset = Wikidata5M() | |||
| # Select the appropriate TriplesFactory based on the requested split | |||
| if split == "train": | |||
| tf = dataset.training | |||
| elif split in ("valid", "val", "validation"): | |||
| tf = dataset.validation | |||
| elif split in ("test", "testing"): | |||
| tf = dataset.testing | |||
| else: | |||
| msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||
| raise ValueError( | |||
| msg, | |||
| ) | |||
| # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3) | |||
| # Convert each row into a Python tuple of ints | |||
| triples_list: List[Tuple[int, int, int]] = [ | |||
| (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||
| ] | |||
| # Get the total number of entities and relations in the entire dataset | |||
| num_entities = dataset.num_entities | |||
| num_relations = dataset.num_relations | |||
| # Initialize the base KGDataset with the selected triples | |||
| super().__init__( | |||
| triples_factory=tf, | |||
| triples=triples_list, | |||
| num_entities=num_entities, | |||
| num_relations=num_relations, | |||
| split=split, | |||
| ) | |||
| def __getitem__(self, idx: int) -> torch.Tensor: | |||
| data = super().__getitem__(idx) | |||
| # Convert the tuple to a torch tensor | |||
| return torch.tensor(data, dtype=torch.long) | |||
| @@ -0,0 +1,112 @@ | |||
| from typing import List, Tuple | |||
| import torch | |||
| from pykeen.datasets import WN18, WN18RR | |||
| from .kg_dataset import KGDataset # Importing KGDataset from the same module | |||
| class WN18Dataset(KGDataset): | |||
| """A wrapper around PyKeen's WN18 dataset that produces KGDataset instances.""" | |||
| def __init__(self, split: str = "train") -> None: | |||
| """ | |||
| Initialize the WN18 wrapper. | |||
| :param split: One of "train", "valid", or "test" indicating which split to load. | |||
| """ | |||
| # Load the PyKeen WN18 dataset | |||
| dataset = WN18() | |||
| # Select the appropriate TriplesFactory based on the requested split | |||
| if split == "train": | |||
| tf = dataset.training | |||
| elif split in ("valid", "val", "validation"): | |||
| tf = dataset.validation | |||
| elif split in ("test", "testing"): | |||
| tf = dataset.testing | |||
| else: | |||
| msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||
| raise ValueError( | |||
| msg, | |||
| ) | |||
| # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3) | |||
| # Convert each row into a Python tuple of ints | |||
| triples_list: List[Tuple[int, int, int]] = [ | |||
| (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||
| ] | |||
| # Get the total number of entities and relations in the entire dataset | |||
| num_entities = dataset.num_entities | |||
| num_relations = dataset.num_relations | |||
| # Initialize the base KGDataset with the selected triples | |||
| super().__init__( | |||
| triples_factory=tf, | |||
| triples=triples_list, | |||
| num_entities=num_entities, | |||
| num_relations=num_relations, | |||
| split=split, | |||
| ) | |||
| def __getitem__(self, idx: int) -> torch.Tensor: | |||
| data = super().__getitem__(idx) | |||
| # Convert the tuple to a torch tensor | |||
| return torch.tensor(data, dtype=torch.long) | |||
| class WN18RRDataset(KGDataset): | |||
| """A wrapper around PyKeen's WN18RR dataset that produces KGDataset instances.""" | |||
| def __init__(self, split: str = "train") -> None: | |||
| """ | |||
| Initialize the WN18RR wrapper. | |||
| :param split: One of "train", "valid", or "test" indicating which split to load. | |||
| """ | |||
| # Load the PyKeen WN18RR dataset | |||
| dataset = WN18RR() | |||
| # Select the appropriate TriplesFactory based on the requested split | |||
| if split == "train": | |||
| tf = dataset.training | |||
| elif split in ("valid", "val", "validation"): | |||
| tf = dataset.validation | |||
| elif split in ("test", "testing"): | |||
| tf = dataset.testing | |||
| else: | |||
| msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||
| raise ValueError( | |||
| msg, | |||
| ) | |||
| # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3) | |||
| # Convert each row into a Python tuple of ints | |||
| triples_list: List[Tuple[int, int, int]] = [ | |||
| (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||
| ] | |||
| # Get the total number of entities and relations in the entire dataset | |||
| num_entities = dataset.num_entities | |||
| num_relations = dataset.num_relations | |||
| # Initialize the base KGDataset with the selected triples | |||
| super().__init__( | |||
| triples_factory=tf, | |||
| triples=triples_list, | |||
| num_entities=num_entities, | |||
| num_relations=num_relations, | |||
| split=split, | |||
| ) | |||
| def __getitem__(self, idx: int) -> torch.Tensor: | |||
| data = super().__getitem__(idx) | |||
| # Convert the tuple to a torch tensor | |||
| return torch.tensor(data, dtype=torch.long) | |||
| @@ -0,0 +1,70 @@ | |||
| from typing import List, Tuple | |||
| import torch | |||
| from pykeen.datasets import PathDataset | |||
| from .kg_dataset import KGDataset | |||
| class YAGO310Dataset(KGDataset): | |||
| """Wrapper around PyKeen's YAGO3-10 dataset.""" | |||
| def __init__(self, split: str = "train") -> None: | |||
| # dataset = YAGO310() | |||
| # NOTE: YAGO3-10 is loaded from local, machine-specific paths instead of the | |||
| # packaged PyKeen dataset. | |||
| dataset = PathDataset( | |||
| training_path="/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10/train.txt", | |||
| testing_path="/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10/test.txt", | |||
| validation_path="/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10/valid.txt", | |||
| ) | |||
| if split == "train": | |||
| # if True: | |||
| tf = dataset.training | |||
| elif split in ("valid", "val", "validation"): | |||
| tf = dataset.validation | |||
| elif split in ("test", "testing"): | |||
| tf = dataset.testing | |||
| else: | |||
| msg = f"Unrecognized split '{split}'" | |||
| raise ValueError(msg) | |||
| custom_triples = torch.load( | |||
| "/home/n_kazemi/projects/KGEvaluation/CustomDataset/yago310/yago310_crec_bi_1.pt", | |||
| map_location=torch.device("cpu"), | |||
| ).to(torch.long) | |||
| # .tolist() | |||
| # get a random sample of size 5000 | |||
| # seed = torch.seed() # For reproducibility | |||
| # torch.manual_seed(seed) | |||
| # torch.cuda.manual_seed(seed) | |||
| # if custom_triples.shape[0] > 5000: | |||
| # custom_triples = custom_triples[torch.randperm(custom_triples.shape[0])[:5000]] | |||
| tf = tf.clone_and_exchange_triples( | |||
| custom_triples, | |||
| ) | |||
| # triples = tf.mapped_triples | |||
| # # get a random sample of size 5000 | |||
| # if triples.shape[0] > 5000: | |||
| # indices = torch.randperm(triples.shape[0])[:5000] | |||
| # triples = triples[indices] | |||
| # tf = tf.clone_and_exchange_triples(triples) | |||
| triples_list: List[Tuple[int, int, int]] = [ | |||
| (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||
| ] | |||
| super().__init__( | |||
| triples_factory=tf, | |||
| triples=triples_list, | |||
| num_entities=dataset.num_entities, | |||
| num_relations=dataset.num_relations, | |||
| split=split, | |||
| ) | |||
| def __getitem__(self, idx: int) -> torch.Tensor: | |||
| return torch.tensor(super().__getitem__(idx), dtype=torch.long) | |||
| @@ -0,0 +1,38 @@ | |||
| from data import ( | |||
| FB15KDataset, | |||
| WN18Dataset, | |||
| WN18RRDataset, | |||
| YAGO310Dataset, | |||
| ) | |||
| from metrics.wlcrec import WLCREC | |||
| def main() -> None: | |||
| datasets = { | |||
| "YAGO3-10": YAGO310Dataset(split="train"), | |||
| # "WN18": WN18Dataset(split="train"), | |||
| # "WN18RR": WN18RRDataset(split="train"), | |||
| # "FB15K": FB15KDataset(split="train"), | |||
| # "Hetionet": HetionetDataset(split="train"), | |||
| # "OpenBioLink": OpenBioLinkDataset(split="train"), | |||
| # "OpenKEWiki": OpenKEWikiDataset(split="train"), | |||
| } | |||
| for name, dataset in datasets.items(): | |||
| wl_crec = WLCREC(dataset) | |||
| entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = wl_crec.compute(H=20) | |||
| print( | |||
| f"\nDataset: {name}", | |||
| f"\nResults (H={20})", | |||
| f"\n • avg WLEC : {entropy:.6f} nats", | |||
| f"\n • C_ratio : {c_ratio:.6f}", | |||
| f"\n • C_NWLEC : {c_nwlec:.6f}", | |||
| f"\n • H_cond(R|S_H) : {h_cond:.6f} nats", | |||
| f"\n • D_ratio : {d_ratio:.6f}", | |||
| f"\n • D_NWLEC : {d_nwlec:.6f}", | |||
| sep="", | |||
| ) | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -0,0 +1,80 @@ | |||
| from copy import deepcopy | |||
| import hydra | |||
| import lovely_tensors as lt | |||
| import torch | |||
| from hydra.utils import instantiate | |||
| from omegaconf import DictConfig | |||
| from metrics.c_swklf import CSWKLF | |||
| from tools import get_pretty_logger | |||
| # (Import whatever Trainer or runner you have) | |||
| from training.trainer import Trainer | |||
| logger = get_pretty_logger(__name__) | |||
| lt.monkey_patch() | |||
| @hydra.main(config_path="configs", config_name="config") | |||
| def main(cfg: DictConfig) -> None: | |||
| # Detect CUDA | |||
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||
| common_params = cfg.common | |||
| training_params = deepcopy(cfg.training) | |||
| # drop model_trainer from cfg.training; it is instantiated separately below | |||
| del training_params.model_trainer | |||
| common_params = instantiate(common_params) | |||
| training_params = instantiate(training_params) | |||
| # dataset instantiation | |||
| train_dataset = instantiate(cfg.data.train) | |||
| val_dataset = instantiate(cfg.data.valid) | |||
| model = instantiate( | |||
| cfg.model, | |||
| num_entities=train_dataset.num_entities, | |||
| num_relations=train_dataset.num_relations, | |||
| triples_factory=train_dataset.triples_factory, | |||
| device=device, | |||
| ) | |||
| model = model.to(device) | |||
| model_trainer = instantiate( | |||
| cfg.training.model_trainer, | |||
| model=model, | |||
| training_params=training_params, | |||
| triples_factory=train_dataset.triples_factory, | |||
| mapped_triples=train_dataset.triples, | |||
| device=model.device, | |||
| ) | |||
| # Instantiate the Trainer with the model_trainer | |||
| trainer = Trainer( | |||
| train_dataset=train_dataset, | |||
| val_dataset=val_dataset, | |||
| model_trainer=model_trainer, | |||
| common_params=common_params, | |||
| training_params=training_params, | |||
| device=model.device, | |||
| ) | |||
| if cfg.common.evaluate_only: | |||
| trainer.evaluate() | |||
| c_swklf_metric = CSWKLF( | |||
| dataset=val_dataset, | |||
| model=model, | |||
| ) | |||
| c_swklf_value = c_swklf_metric.compute() | |||
| logger.info(f"CSWKLF Metric Value: {c_swklf_value}") | |||
| else: | |||
| trainer.run() | |||
| if __name__ == "__main__": | |||
| main() | |||
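| # Programmatic composition sketch (assumes Hydra's compose API is available in | |||
| # the installed version; "configs" is the same directory used by @hydra.main): | |||
| # from hydra import compose, initialize | |||
| # with initialize(config_path="configs"): | |||
| #     cfg = compose(config_name="config", overrides=["data=wn18", "model=trans_e"]) | |||
| #     train_dataset = instantiate(cfg.data.train) | |||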
| @@ -0,0 +1,55 @@ | |||
| from __future__ import annotations | |||
| from abc import ABC, abstractmethod | |||
| from typing import TYPE_CHECKING, Any, List, Tuple | |||
| if TYPE_CHECKING: | |||
| from data.kg_dataset import KGDataset | |||
| class BaseMetric(ABC): | |||
| """Base class for metrics that compute Weisfeiler-Lehman Entropy | |||
| Complexity (WLEC) or similar metrics. | |||
| """ | |||
| def __init__(self, dataset: KGDataset) -> None: | |||
| self.dataset = dataset | |||
| self.num_entities = dataset.num_entities | |||
| self.out_adj, self.in_adj = self._build_adjacency_lists() | |||
| def _build_adjacency_lists( | |||
| self, | |||
| ) -> Tuple[List[List[Tuple[int, int]]], List[List[Tuple[int, int]]]]: | |||
| """Create *in* and *out* adjacency lists from mapped triples. | |||
| Parameters | |||
| ---------- | |||
| triples: | |||
| Tensor of shape *(m, 3)* containing *(h, r, t)* triples (``dtype=torch.long``). | |||
| num_entities: | |||
| Total number of entities. Needed to allocate adjacency containers. | |||
| Returns | |||
| ------- | |||
| out_adj, in_adj | |||
| Each element ``out_adj[v]`` is a list ``[(r, t), ...]``. Each element | |||
| ``in_adj[v]`` is a list ``[(r, h), ...]``. | |||
| """ | |||
| triples = self.dataset.triples # (m, 3) | |||
| num_entities = self.num_entities | |||
| out_adj: List[List[Tuple[int, int]]] = [[] for _ in range(num_entities)] | |||
| in_adj: List[List[Tuple[int, int]]] = [[] for _ in range(num_entities)] | |||
| for h, r, t in triples.tolist(): | |||
| out_adj[h].append((r, t)) | |||
| in_adj[t].append((r, h)) | |||
| return out_adj, in_adj | |||
| @abstractmethod | |||
| def compute( | |||
| self, | |||
| *args, | |||
| **kwargs, | |||
| ) -> Any: ... | |||
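| # Illustrative sketch (not part of the original module): a minimal concrete | |||
| # subclass showing how `compute` can use the adjacency lists built above. | |||
| class MeanOutDegree(BaseMetric): | |||
| """Toy metric: average out-degree over all entities.""" | |||
| def compute(self) -> float: | |||
| # each out_adj[v] holds the (r, t) pairs leaving entity v | |||
| return sum(len(nbrs) for nbrs in self.out_adj) / max(self.num_entities, 1) | |||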
| @@ -0,0 +1,264 @@ | |||
| from __future__ import annotations | |||
| import math | |||
| from collections import defaultdict | |||
| from dataclasses import dataclass | |||
| from typing import TYPE_CHECKING, Callable, Iterable | |||
| import torch | |||
| from torch import Tensor | |||
| from .base_metric import BaseMetric | |||
| if TYPE_CHECKING: | |||
| from data.kg_dataset import KGDataset | |||
| from models.base_model import ERModel | |||
| @torch.no_grad() | |||
| def _log_prob_distribution( | |||
| fn: Callable[[Tensor], Tensor], | |||
| query: Tensor, | |||
| batch_size: int, | |||
| ) -> Iterable[Tensor]: | |||
| """ | |||
| Yield *log*-probability batches for an arbitrary distribution helper. | |||
| Parameters | |||
| ---------- | |||
| fn : | |||
| A bound method of *model* returning **log**-probabilities, e.g. | |||
| ``model.tail_distribution(..., log=True)``. | |||
| query : | |||
| LongTensor of shape (N, 2) containing the query IDs that `fn` expects. | |||
| batch_size : | |||
| Mini-batch size. Choose according to available GPU/CPU memory. | |||
| """ | |||
| for i in range(0, len(query), batch_size): | |||
| yield fn(query[i : i + batch_size]) # (B, |Candidates|) | |||
| def _kl_uniform(log_p: Tensor, true_index_sets: list[list[int]]) -> Tensor: | |||
| """ | |||
| KL(q || p) where q is uniform over *true_index_sets*. | |||
| Parameters | |||
| ---------- | |||
| log_p : | |||
| Tensor of shape (B, N) — log-probabilities from the model. | |||
| true_index_sets : | |||
| List (length B). Element *b* contains the list of **indices** | |||
| that are correct for row *b*. | |||
| Returns | |||
| ------- | |||
| kl : Tensor, shape (B,) — KL divergence per row (natural log). | |||
| """ | |||
| rows: list[Tensor] = [] | |||
| for lp_row, idx in zip(log_p, true_index_sets): | |||
| k = len(idx) | |||
| rows.append(-math.log(k) - lp_row[idx].mean())  # KL(q || p) = -log k - E_q[log p] | |||
| return torch.tensor(rows, device=log_p.device) | |||
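| # Sanity check (sketch): if p is itself uniform over the true indices, the KL | |||
| # above is 0, e.g. with k = 2 true indices out of 4 candidates: | |||
| #   log_p = torch.log(torch.tensor([[0.5, 0.5, 0.0, 0.0]]).clamp_min(1e-12)) | |||
| #   _kl_uniform(log_p, [[0, 1]])  # ≈ tensor([0.]) | |||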
| @dataclass(slots=True) | |||
| class _Accum: | |||
| """Simple accumulator for ∑ value and ∑ weight.""" | |||
| tot_v: float = 0.0 | |||
| tot_w: float = 0.0 | |||
| def update(self, value: float, weight: float) -> None: | |||
| self.tot_v += value * weight | |||
| self.tot_w += weight | |||
| @property | |||
| def mean(self) -> float: | |||
| return self.tot_v / self.tot_w if self.tot_w else float("nan") | |||
| class CSWKLF(BaseMetric): | |||
| """ | |||
| C-SWKLF metric for evaluating knowledge graph embeddings. | |||
| This metric is used to evaluate the quality of knowledge graph embeddings | |||
| by computing the C-SWKLF score. | |||
| """ | |||
| def __init__(self, dataset: KGDataset, model: ERModel) -> None: | |||
| super().__init__(dataset) | |||
| self.model = model | |||
| @torch.no_grad() | |||
| def compute( | |||
| self, | |||
| alpha: float = 1 / 3, | |||
| beta: float = 1 / 3, | |||
| gamma: float = 1 / 3, | |||
| batch_size: int = 1024, | |||
| slice_size: int | None = 2048, | |||
| device: torch.device | str | None = None, | |||
| ) -> float: | |||
| """ | |||
| Compute *Comprehensive Structural-Weighted KL-Fitness*. | |||
| Uses ``self.model`` (a trained PyKEEN ERModel exposing the tail / head / relation | |||
| distribution helpers) and ``self.dataset`` for the empirical triples. | |||
| Parameters | |||
| ---------- | |||
| alpha, beta, gamma : | |||
| Non-negative weights for the three query types, default 1/3 each. | |||
| Must sum to *1*. | |||
| batch_size : | |||
| Number of queries scored per forward pass. Tune wrt. GPU/CPU RAM. | |||
| slice_size : | |||
| Forwarded to `tail_distribution` / `head_distribution` / | |||
| `relation_distribution`. `None` disables slicing. | |||
| device : | |||
| Where to do the computation. `None` ⇒ first param's device. | |||
| Returns | |||
| ------- | |||
| score : float in (0, 1] | |||
| Higher = model closer to empirical distributions across *all* directions. | |||
| """ | |||
| # --------------------------------------------------------------------- # | |||
| # Preparation # | |||
| # --------------------------------------------------------------------- # | |||
| assert abs(alpha + beta + gamma - 1.0) < 1e-9, "α+β+γ must be 1." | |||
| model_device = next(self.model.parameters()).device | |||
| if device is None: | |||
| device = model_device | |||
| triples = self.dataset.triples.to(device) # (|T|, 3) | |||
| heads, rels, tails = triples.t() # (|T|,) | |||
| # Build index structures -------------------------------------------------- | |||
| tails_by_hr: dict[tuple[int, int], list[int]] = defaultdict(list) | |||
| heads_by_rt: dict[tuple[int, int], list[int]] = defaultdict(list) | |||
| rels_by_ht: dict[tuple[int, int], list[int]] = defaultdict(list) | |||
| for h, r, t in zip(heads.tolist(), rels.tolist(), tails.tolist()): | |||
| tails_by_hr[(h, r)].append(t) | |||
| heads_by_rt[(r, t)].append(h) | |||
| rels_by_ht[(h, t)].append(r) | |||
| # Structural entropies & query lists --------------------------------------- | |||
| H_tail: dict[int, _Accum] = defaultdict(_Accum) # per relation r | |||
| H_head: dict[int, _Accum] = defaultdict(_Accum) | |||
| tail_queries: list[tuple[int, int]] = [] # (h,r) | |||
| head_queries: list[tuple[int, int]] = [] # (r,t) | |||
| rel_queries: list[tuple[int, int]] = [] # (h,t) | |||
| # --- tails -------------------------------------------------------------- | |||
| for (h, r), ts in tails_by_hr.items(): | |||
| k = len(ts) | |||
| H_tail[r].update(math.log(k), 1.0) | |||
| tail_queries.append((h, r)) | |||
| # --- heads -------------------------------------------------------------- | |||
| for (r, t), hs in heads_by_rt.items(): | |||
| k = len(hs) | |||
| H_head[r].update(math.log(k), 1.0) | |||
| head_queries.append((r, t)) | |||
| # --- relations ---------------------------------------------------------- | |||
| H_rel_accum = _Accum() | |||
| for (h, t), rs in rels_by_ht.items(): | |||
| k = len(rs) | |||
| H_rel_accum.update(math.log(k), 1.0) | |||
| rel_queries.append((h, t)) | |||
| H_struct_tail = {r: acc.mean for r, acc in H_tail.items()} | |||
| H_struct_head = {r: acc.mean for r, acc in H_head.items()} | |||
| # H_struct_rel = H_rel_accum.mean # scalar | |||
| # --------------------------------------------------------------------- # | |||
| # KL divergences # | |||
| # --------------------------------------------------------------------- # | |||
| def _avg_kl_per_relation( | |||
| queries: list[tuple[int, int]], | |||
| true_idx_map: dict[tuple[int, int], list[int]], | |||
| distribution_fn: Callable[[Tensor], Tensor], | |||
| num_relations: bool = False,  # if True, the relation ID is the first element of each query tuple (e.g. (r, t) head queries) | |||
| ) -> dict[int, float]: | |||
| """ | |||
| Compute *average* KL(q || p) **per relation**. | |||
| Works for the tail & head conditionals. | |||
| """ | |||
| kl_sum: dict[int, float] = defaultdict(float) | |||
| count: dict[int, int] = defaultdict(int) | |||
| query_tensor = torch.tensor(queries, dtype=torch.long, device=device) | |||
| for log_p in _log_prob_distribution(distribution_fn, query_tensor, batch_size): | |||
| # slice_size handled inside distribution_fn | |||
| # log_p : (B, |Candidates|) | |||
| b = log_p.shape[0] | |||
| # Map back to current slice of queries | |||
| slice_queries = queries[:b] | |||
| queries[:] = queries[b:] # consume | |||
| true_sets = [true_idx_map[q] for q in slice_queries] | |||
| kl_row = _kl_uniform(log_p, true_sets) # (B,) | |||
| for (x, r), kl_val in zip(slice_queries, kl_row.tolist()): | |||
| rel_id = r if not num_relations else x | |||
| kl_sum[rel_id] += kl_val | |||
| count[rel_id] += 1 | |||
| return {r: kl_sum[r] / count[r] for r in kl_sum} | |||
| # --- tails -------------------------------------------------------------- | |||
| tail_queries_copy = tail_queries.copy() | |||
| D_tail = _avg_kl_per_relation( | |||
| tail_queries_copy, | |||
| tails_by_hr, | |||
| lambda q, s=slice_size: self.model.tail_distribution(q, log=True, slice_size=s), | |||
| ) | |||
| S_tail = {r: math.exp(-d) for r, d in D_tail.items()} | |||
| # --- heads -------------------------------------------------------------- | |||
| head_queries_copy = head_queries.copy() | |||
| D_head = _avg_kl_per_relation( | |||
| head_queries_copy, | |||
| heads_by_rt, | |||
| lambda q, s=slice_size: self.model.head_distribution(q, log=True, slice_size=s), | |||
| num_relations=True, | |||
| ) | |||
| S_head = {r: math.exp(-d) for r, d in D_head.items()} | |||
| # --- relations (single global number) ---------------------------------- | |||
| D_rel_sum = 0.0 | |||
| for log_p in _log_prob_distribution( | |||
| lambda q, s=slice_size: self.model.relation_distribution(q, log=True, slice_size=s), | |||
| torch.tensor(rel_queries, dtype=torch.long, device=device), | |||
| batch_size, | |||
| ): | |||
| slice_size_batch = log_p.shape[0] | |||
| slice_queries = rel_queries[:slice_size_batch] | |||
| rel_queries[:] = rel_queries[slice_size_batch:] | |||
| true_sets = [rels_by_ht[q] for q in slice_queries] | |||
| D_rel_sum += _kl_uniform(log_p, true_sets).sum().item() | |||
| D_rel = D_rel_sum / H_rel_accum.tot_w | |||
| S_rel = math.exp(-D_rel) | |||
| # --------------------------------------------------------------------- # | |||
| # Weighted aggregation → C-SWKLF # | |||
| # --------------------------------------------------------------------- # | |||
| def _weighted_mean(score: dict[int, float], weight: dict[int, float]) -> float: | |||
| num = sum(weight[r] * score[r] for r in score) | |||
| den = sum(weight[r] for r in score) | |||
| return num / den if den else float("nan") | |||
| sw_tail = _weighted_mean(S_tail, H_struct_tail) | |||
| sw_head = _weighted_mean(S_head, H_struct_head) | |||
| # Relations: numerator & denominator cancel | |||
| sw_rel = S_rel | |||
| return alpha * sw_tail + beta * sw_head + gamma * sw_rel | |||
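| # Usage sketch (assumes a trained model exposing the tail/head/relation | |||
| # distribution helpers described in `compute`; names below are illustrative): | |||
| # metric = CSWKLF(dataset=valid_dataset, model=trained_model) | |||
| # score = metric.compute(alpha=1 / 3, beta=1 / 3, gamma=1 / 3, batch_size=512) | |||
| # print(f"C-SWKLF: {score:.4f}")  # closer to 1 = better distributional fit | |||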
| @@ -0,0 +1,732 @@ | |||
| # from __future__ import annotations | |||
| # import math | |||
| # import os | |||
| # import random | |||
| # import threading | |||
| # from collections import defaultdict | |||
| # from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| # from typing import Dict, List, Tuple | |||
| # import torch | |||
| # from data.kg_dataset import KGDataset | |||
| # from metrics.wlcrec import WLCREC # Assuming WLCREC is defined in | |||
| # from tools import get_pretty_logger | |||
| # logger = get_pretty_logger(__name__) | |||
| # # -------------------------------------------------------------------- | |||
| # # Edge-editing primitives | |||
| # # -------------------------------------------------------------------- | |||
| # # -- additions -------------------- | |||
| # def add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng): | |||
| # for _ in range(1000): | |||
| # h = rng.randrange(n_ent) | |||
| # t = rng.randrange(n_ent) | |||
| # if colours[h] != colours[t]: | |||
| # r = rng.randrange(n_rel) | |||
| # triples.append((h, r, t)) | |||
| # out_adj[h].append((r, t)) | |||
| # in_adj[t].append((r, h)) | |||
| # return | |||
| # def add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng): | |||
| # σ = rng.choice(colours) | |||
| # τ = rng.choice(colours) | |||
| # h = colours.index(σ) | |||
| # t = colours.index(τ) | |||
| # triples.append((h, 0, t)) | |||
| # out_adj[h].append((0, t)) | |||
| # in_adj[t].append((0, h)) | |||
| # # -- removals -------------------- | |||
| # def rem_div_edge(triples, out_adj, in_adj, colours, depth, rng): | |||
| # cand = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]] | |||
| # if cand: | |||
| # idx = rng.choice(cand) | |||
| # h, r, t = triples.pop(idx) | |||
| # out_adj[h].remove((r, t)) | |||
| # in_adj[t].remove((r, h)) | |||
| # def rem_det_edge(triples, out_adj, in_adj, colours, depth, rng): | |||
| # sig_rel = defaultdict(set) | |||
| # for h, r, t in triples: | |||
| # σ, τ = colours[h], colours[t] | |||
| # sig_rel[(σ, τ)].add(r) | |||
| # cand = [] | |||
| # for i, (h, r, t) in enumerate(triples): | |||
| # if len(sig_rel[(colours[h], colours[t])]) == 1: | |||
| # cand.append(i) | |||
| # if cand: | |||
| # idx = rng.choice(cand) | |||
| # h, r, t = triples.pop(idx) | |||
| # out_adj[h].remove((r, t)) | |||
| # in_adj[t].remove((r, h)) | |||
| # def _search_worker( | |||
| # seed: int, | |||
| # triples_init: List[Tuple[int, int, int]], | |||
| # triples_factory, | |||
| # n_ent: int, | |||
| # n_rel: int, | |||
| # depth: int, | |||
| # lo: float, | |||
| # hi: float, | |||
| # max_iters: int, | |||
| # ) -> List[Tuple[int, int, int]] | None: | |||
| # """ | |||
| # Run the exact same hill‑climb that `tune_crec()` did, | |||
| # but entirely in this process. Return the edited triples | |||
| # once c falls in [lo, hi]; return None if we used up all iterations. | |||
| # """ | |||
| # rng = random.Random(seed) | |||
| # triples = triples_init.copy().tolist() | |||
| # for it in range(max_iters): | |||
| # # WL‑CREC is *only* recomputed every 1000 edits, exactly like before | |||
| # if it % 1000 == 0: | |||
| # dataset = KGDataset( | |||
| # triples_factory, | |||
| # triples=torch.tensor(triples, dtype=torch.long), | |||
| # num_entities=n_ent, | |||
| # num_relations=n_rel, | |||
| # ) | |||
| # wl = WLCREC(dataset) | |||
| # colours = wl.wl_colours(depth) | |||
| # *_, c, _ = wl.compute(H=5) # unchanged API | |||
| # if lo <= c <= hi: # success | |||
| # logger.info("[seed %d] hit %.4f in %d edits", seed, c, it) | |||
| # return triples | |||
| # # ---------- identical edit logic ---------- | |||
| # if c < lo: | |||
| # if rng.random() < 0.5: | |||
| # rem_det_edge(triples, colours, depth, rng) | |||
| # else: | |||
| # add_div_edge(triples, colours, depth, n_ent, n_rel, rng) | |||
| # else: # c > hi | |||
| # if rng.random() < 0.5: | |||
| # rem_div_edge(triples, colours, depth, rng) | |||
| # else: | |||
| # add_det_edge(triples, colours, depth, n_rel, rng) | |||
| # # ------------------------------------------ | |||
| # return None # used up our budget | |||
| # # -------------------------------------------------------------------- | |||
| # # Unified tuner | |||
| # # -------------------------------------------------------------------- | |||
| # def tune_crec( | |||
| # wl_crec: WLCREC, | |||
| # n_ent: int, | |||
| # n_rel: int, | |||
| # depth: int, | |||
| # lo: float, | |||
| # hi: float, | |||
| # max_iters: int = 80_000, | |||
| # seed: int = 42, | |||
| # ): | |||
| # triples = wl_crec.dataset.triples.tolist() | |||
| # rng = random.Random(seed) | |||
| # for it in range(max_iters): | |||
| # print(f"\r[iter {it + 1:5d}] ", end="") | |||
| # if it % 1000 == 0: | |||
| # dataset = KGDataset( | |||
| # wl_crec.dataset.triples_factory, | |||
| # triples=torch.tensor(triples, dtype=torch.long), | |||
| # num_entities=n_ent, | |||
| # num_relations=n_rel, | |||
| # ) | |||
| # tmp_wl_crec = WLCREC(dataset) | |||
| # # colours = wl_colours(triples, n_ent, depth) | |||
| # colours = tmp_wl_crec.wl_colours(depth) | |||
| # _, _, _, _, c, _ = tmp_wl_crec.compute(H=5) | |||
| # if lo <= c <= hi: | |||
| # logging.info("WL-CREC %.4f reached after %d edits (|T|=%d)", c, it, len(triples)) | |||
| # return triples | |||
| # if c < lo: | |||
| # # need ↑ WL-CREC → prefer deletion of deterministic, else add diversifying | |||
| # if rng.random() < 0.5: | |||
| # rem_det_edge(triples, colours, depth, rng) | |||
| # else: | |||
| # add_div_edge(triples, colours, depth, n_ent, n_rel, rng) | |||
| # # need ↓ WL-CREC → prefer deletion of diversifying, else add deterministic | |||
| # elif rng.random() < 0.5: | |||
| # rem_div_edge(triples, colours, depth, rng) | |||
| # else: | |||
| # add_det_edge(triples, colours, depth, n_rel, rng) | |||
| # if (it + 1) % 10000 == 1: | |||
| # logging.info("[iter %d] WL-CREC %.4f |T|=%d", it + 1, c, len(triples)) | |||
| # raise RuntimeError("Exceeded max iterations without hitting target band.") | |||
| # def _edit_batch( | |||
| # worker_id: int, | |||
| # n_edits: int, | |||
| # colours: List[int], | |||
| # crec: float, | |||
| # target_lo: float, | |||
| # target_hi: float, | |||
| # depth: int, | |||
| # n_ent: int, | |||
| # n_rel: int, | |||
| # seed_base: int, | |||
| # ): | |||
| # """ | |||
| # Perform `n_edits` topology modifications *locally* and return | |||
| # (added_triples, removed_triples) lists. | |||
| # """ | |||
| # rng = random.Random(seed_base + worker_id) | |||
| # local_added, local_removed = [], [] | |||
| # # local graph views: only sizes matter for edit selection | |||
| # for _ in range(n_edits): | |||
| # if crec < target_lo: # need ↑ WL-CREC | |||
| # if rng.random() < 0.5: # • remove deterministic | |||
| # # We cannot remove by index safely without the global list; | |||
| # # choose a *signature* and remember intention to delete: | |||
| # local_removed.append(("det", rng.random())) | |||
| # else: # • add diversifying | |||
| # h = rng.randrange(n_ent) | |||
| # t = rng.randrange(n_ent) | |||
| # while colours[h] == colours[t]: | |||
| # h = rng.randrange(n_ent) | |||
| # t = rng.randrange(n_ent) | |||
| # r = rng.randrange(n_rel) | |||
| # local_added.append((h, r, t)) | |||
| # else: # need ↓ WL-CREC | |||
| # if rng.random() < 0.5: # • remove diversifying | |||
| # local_removed.append(("div", rng.random())) | |||
| # else: # • add deterministic | |||
| # σ = rng.choice(colours) | |||
| # τ = rng.choice(colours) | |||
| # h = colours.index(σ) | |||
| # t = colours.index(τ) | |||
| # local_added.append((h, 0, t)) | |||
| # return local_added, local_removed | |||
| # def wl_colours(triples, n_ent, depth): | |||
| # # build adjacency once | |||
| # out_adj, in_adj = defaultdict(list), defaultdict(list) | |||
| # for h, r, t in triples: | |||
| # out_adj[h].append((r, t)) | |||
| # in_adj[t].append((r, h)) | |||
| # colours_rounds = [[0] * n_ent] # round-0 colours | |||
| # for h in range(1, depth + 1): | |||
| # prev = colours_rounds[-1] | |||
| # # 1) build textual signatures in parallel | |||
| # def sig(v): | |||
| # neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [ | |||
| # ("↑", r, prev[u]) for r, u in in_adj.get(v, []) | |||
| # ] | |||
| # neigh.sort() | |||
| # return (prev[v], tuple(neigh)) | |||
| # with ThreadPoolExecutor() as tpe: # cheap threads inside worker | |||
| # sigs = list(tpe.map(sig, range(n_ent))) | |||
| # # 2) assign deterministic colour IDs | |||
| # sig2id: Dict[Tuple, int] = {} | |||
| # next_round = [0] * n_ent | |||
| # fresh = 0 | |||
| # for v, sg in enumerate(sigs): | |||
| # cid = sig2id.setdefault(sg, fresh) | |||
| # if cid == fresh: | |||
| # fresh += 1 | |||
| # next_round[v] = cid | |||
| # colours_rounds.append(next_round) | |||
| # depth_colours = colours_rounds[-1] | |||
| # return depth_colours | |||
| # def _metric_worker(args): | |||
| # triples, n_ent, n_rel, depth = args | |||
| # dataset = KGDataset( | |||
| # triples_factory=None, # Not used in this context | |||
| # triples=torch.tensor(triples, dtype=torch.long), | |||
| # num_entities=n_ent, | |||
| # num_relations=n_rel, | |||
| # ) | |||
| # wl_crec = WLCREC(dataset) | |||
| # _, _, _, _, c, _ = wl_crec.compute(H=5, return_full=False) | |||
| # colours = wl_colours(triples, n_ent, depth) | |||
| # return c, colours | |||
| # def tune_crec_parallel_edits( | |||
| # triples_init, | |||
| # n_ent: int, | |||
| # n_rel: int, | |||
| # depth: int, | |||
| # target_lo: float, | |||
| # target_hi: float, | |||
| # max_iters: int = 80_000, | |||
| # metric_every: int = 100, | |||
| # n_workers: int = max(20, math.ceil(os.cpu_count() / 2)), | |||
| # seed: int = 42, | |||
| # ): | |||
| # # -------- shared mutable state (main thread owns it) -------- | |||
| # triples = triples_init.tolist() | |||
| # out_adj, in_adj = defaultdict(list), defaultdict(list) | |||
| # for h, r, t in triples: | |||
| # out_adj[h].append((r, t)) | |||
| # in_adj[t].append((r, h)) | |||
| # pool = ThreadPoolExecutor(max_workers=n_workers) | |||
| # metric_lock = threading.Lock() # exactly one metric at a time | |||
| # rng_global = random.Random(seed) | |||
| # # ----- first metric checkpoint ----- | |||
| # crec, colours = _metric_worker((triples, n_ent, n_rel, depth)) | |||
| # edit_budget_total = 0 | |||
| # for it in range(0, max_iters, metric_every): | |||
| # # ========================================================= | |||
| # # 1. PARALLEL EDIT STAGE (metric_every edits in total) | |||
| # # ========================================================= | |||
| # futures = [] | |||
| # edits_per_worker = metric_every // n_workers | |||
| # extra = metric_every % n_workers | |||
| # for wid in range(n_workers): | |||
| # n_edits = edits_per_worker + (1 if wid < extra else 0) | |||
| # futures.append( | |||
| # pool.submit( | |||
| # _edit_batch, | |||
| # wid, | |||
| # n_edits, | |||
| # colours, | |||
| # crec, | |||
| # target_lo, | |||
| # target_hi, | |||
| # depth, | |||
| # n_ent, | |||
| # n_rel, | |||
| # seed, | |||
| # ) | |||
| # ) | |||
| # # merge when workers finish | |||
| # for fut in as_completed(futures): | |||
| # added, removed_specs = fut.result() | |||
| # # --- apply additions immediately (cheap, conflict-free) | |||
| # for h, r, t in added: | |||
| # triples.append((h, r, t)) | |||
| # out_adj[h].append((r, t)) | |||
| # in_adj[t].append((r, h)) | |||
| # # --- apply removals: interpret the spec on *current* graph | |||
| # for typ, randv in removed_specs: | |||
| # if typ == "div": | |||
| # idxs = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]] | |||
| # else: # 'det' | |||
| # sig_rel = defaultdict(set) | |||
| # for h, r, t in triples: | |||
| # sig_rel[(colours[h], colours[t])].add(r) | |||
| # idxs = [ | |||
| # i | |||
| # for i, (h, r, t) in enumerate(triples) | |||
| # if len(sig_rel[(colours[h], colours[t])]) == 1 | |||
| # ] | |||
| # if idxs: | |||
| # victim = idxs[int(randv * len(idxs))] | |||
| # h, r, t = triples.pop(victim) | |||
| # out_adj[h].remove((r, t)) | |||
| # in_adj[t].remove((r, h)) | |||
| # edit_budget_total += metric_every | |||
| # # ========================================================= | |||
| # # 2. SINGLE-THREADED METRIC CHECKPOINT (synchronised) | |||
| # # ========================================================= | |||
| # # if edit_budget_total % (10 * metric_every) == 0: | |||
| # if True: | |||
| # with metric_lock: | |||
| # crec, colours = _metric_worker((triples, n_ent, n_rel, depth)) | |||
| # logging.info( | |||
| # "After %d edits WL-CREC = %.4f |T|=%d", edit_budget_total, crec, len(triples) | |||
| # ) | |||
| # if target_lo <= crec <= target_hi: | |||
| # logging.info("Target band reached.") | |||
| # pool.shutdown(wait=True) | |||
| # return triples | |||
| # pool.shutdown(wait=True) | |||
| # raise RuntimeError("Exceeded max iteration budget without success.") | |||
| # from __future__ import annotations | |||
| # import logging | |||
| # import random | |||
| # from collections import defaultdict | |||
| # from concurrent.futures import ThreadPoolExecutor, as_completed | |||
| # from typing import List, Tuple | |||
| # import torch | |||
| # from data.kg_dataset import KGDataset | |||
| # from metrics.wlcrec import WLCREC | |||
| # from tools import get_pretty_logger | |||
| # logger = get_pretty_logger(__name__) | |||
| # # ------------ additions ------------ | |||
| # def propose_add_div(triples: List, colours, n_ent: int, n_rel: int, rng: random.Random): | |||
| # for _ in range(1000): | |||
| # h, t = rng.randrange(n_ent), rng.randrange(n_ent) | |||
| # if colours[h] != colours[t]: | |||
| # r = rng.randrange(n_rel) | |||
| # return ("add", (h, r, t)) | |||
| # return None # fell through – extremely rare | |||
| # def propose_add_det(colours, n_rel: int, rng: random.Random): | |||
| # σ, τ = rng.choice(colours), rng.choice(colours) | |||
| # h, t = colours.index(σ), colours.index(τ) | |||
| # return ("add", (h, 0, t)) # rel 0 = deterministic | |||
| # # ------------ removals ------------ | |||
| # def propose_rem_div(triples: List, colours, rng: random.Random): | |||
| # cand = [trp for trp in triples if colours[trp[0]] != colours[trp[2]]] | |||
| # return ("rem", rng.choice(cand)) if cand else None | |||
| # def propose_rem_det(triples: List, colours, rng: random.Random): | |||
| # sig_rel = defaultdict(set) | |||
| # for h, r, t in triples: | |||
| # sig_rel[(colours[h], colours[t])].add(r) | |||
| # cand = [trp for trp in triples if len(sig_rel[(colours[trp[0]], colours[trp[2]])]) == 1] | |||
| # return ("rem", rng.choice(cand)) if cand else None | |||
| # def make_edit_proposal( | |||
| # triples_snapshot: List, | |||
| # colours_snapshot, | |||
| # c: float, | |||
| # lo: float, | |||
| # hi: float, | |||
| # n_ent: int, | |||
| # n_rel: int, | |||
| # depth: int, # still here for future use / signatures | |||
| # seed: int, | |||
| # ): | |||
| # """Return exactly one ('add' | 'rem', triple) proposal or None.""" | |||
| # rng = random.Random(seed) | |||
| # # -- decide which kind of edit we want, *given* the current c ------------ | |||
| # if c < lo: # ↑ WL‑CREC (delete deterministic ∨ add diversifying) | |||
| # chooser = (propose_rem_det, propose_add_div) | |||
| # else: # ↓ WL‑CREC (delete diversifying ∨ add deterministic) | |||
| # chooser = (propose_rem_div, propose_add_det) | |||
| # op = rng.choice(chooser) | |||
| # return ( | |||
| # op(triples_snapshot, colours_snapshot, n_ent, n_rel, rng) | |||
| # if op.__name__.startswith("propose_add") | |||
| # else op(triples_snapshot, colours_snapshot, rng) | |||
| # ) | |||
| # def tune_crec_parallel_edits( | |||
| # triples: List, | |||
| # n_ent: int, | |||
| # n_rel: int, | |||
| # depth: int, | |||
| # lo: float, | |||
| # hi: float, | |||
| # *, | |||
| # max_iters: int = 80_000, | |||
| # edits_per_eval: int = 1000, # == old “if it % 1000 == 0” | |||
| # batch_size: int = 256, # how many proposals we farm out at once | |||
| # max_workers: int = 4, | |||
| # seed: int = 42, | |||
| # ) -> List: | |||
| # rng_global = random.Random(seed) | |||
| # triples = set(triples) # deduplicate if needed | |||
| # with ThreadPoolExecutor(max_workers=max_workers) as pool: | |||
| # proposal_seed = seed * 997 # deterministic but different stream | |||
| # # edit_counter = 0 | |||
| # # while edit_counter < max_iters: | |||
| # # # ----------------- expensive part (single‑thread) ---------------- | |||
| # # dataset = KGDataset( | |||
| # # triples_factory=None, # Not used in this context | |||
| # # triples=torch.tensor(triples, dtype=torch.long), | |||
| # # num_entities=n_ent, | |||
| # # num_relations=n_rel, | |||
| # # ) | |||
| # # tmp = WLCREC(dataset) | |||
| # # colours = tmp.wl_colours(depth) | |||
| # # *_, c, _ = tmp.compute(H=5) | |||
| # # # ----------------------------------------------------------------- | |||
| # # if lo <= c <= hi: | |||
| # # logger.info( | |||
| # # "WL‑CREC %.4f reached after %d edits |T|=%d", c, edit_counter, len(triples) | |||
| # # ) | |||
| # # return triples | |||
| # # ============ parallel block: just make `edits_per_eval` proposals | |||
| # needed = min(edits_per_eval, max_iters - edit_counter) | |||
| # proposals = [] | |||
| # while len(proposals) < needed: | |||
| # # launch a batch of workers | |||
| # futs = [ | |||
| # pool.submit( | |||
| # make_edit_proposal, | |||
| # triples, | |||
| # colours, | |||
| # c, | |||
| # lo, | |||
| # hi, | |||
| # n_ent, | |||
| # n_rel, | |||
| # depth, | |||
| # proposal_seed + i, | |||
| # ) | |||
| # for i in range(batch_size) | |||
| # ] | |||
| # for f in as_completed(futs): | |||
| # prop = f.result() | |||
| # if prop is not None: | |||
| # proposals.append(prop) | |||
| # if len(proposals) == needed: | |||
| # break | |||
| # proposal_seed += batch_size # move RNG window forward | |||
| # # -------------- apply the gathered proposals *sequentially* ------- | |||
| # for kind, trp in proposals: | |||
| # if kind == "add": | |||
| # triples.append(trp) | |||
| # else: # "rem" | |||
| # try: | |||
| # triples.remove(trp) | |||
| # except ValueError: | |||
| # pass # already gone – benign collision | |||
| # # ----------------------------------------------------------------- | |||
| # edit_counter += needed | |||
| # if edit_counter % 1_000 == 0: | |||
| # logger.info("[iter %d] c=%.4f |T|=%d", edit_counter, c, len(triples)) | |||
| # raise RuntimeError("Exceeded max_iters without hitting target band.") | |||
| import os | |||
| import random | |||
| import threading | |||
| from collections import defaultdict | |||
| from typing import Dict, List, Tuple | |||
| # -------------------------------------------------------------------- | |||
| # Logging helper (replace with your own if you prefer) -------------- | |||
| # -------------------------------------------------------------------- | |||
| from tools import get_pretty_logger | |||
| logging = get_pretty_logger(__name__) | |||
| # -------------------------------------------------------------------- | |||
| # Original edge‑editing primitives (unchanged) ---------------------- | |||
| # -------------------------------------------------------------------- | |||
def add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng):
"""Insert one edge between two differently-coloured entities (pushes WL-CREC up)."""
for _ in range(1000):
h = rng.randrange(n_ent)
t = rng.randrange(n_ent)
if colours[h] != colours[t]:
r = rng.randrange(n_rel)
triples.append((h, r, t))
out_adj[h].append((r, t))
in_adj[t].append((r, h))
return
# fell through: no cross-colour pair found in 1000 draws (extremely rare)
def add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng):
"""Insert one edge labelled with relation 0, the designated deterministic relation (pushes WL-CREC down)."""
σ = rng.choice(colours)
τ = rng.choice(colours)
h = colours.index(σ) # first entity carrying colour σ
t = colours.index(τ)
triples.append((h, 0, t))
out_adj[h].append((0, t))
in_adj[t].append((0, h))
| def rem_div_edge(triples, out_adj, in_adj, colours, depth, rng): | |||
| cand = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]] | |||
| if cand: | |||
| idx = rng.choice(cand) | |||
| h, r, t = triples.pop(idx) | |||
| out_adj[h].remove((r, t)) | |||
| in_adj[t].remove((r, h)) | |||
| def rem_det_edge(triples, out_adj, in_adj, colours, depth, rng): | |||
| sig_rel = defaultdict(set) | |||
| for h, r, t in triples: | |||
| σ, τ = colours[h], colours[t] | |||
| sig_rel[(σ, τ)].add(r) | |||
| cand = [] | |||
| for i, (h, r, t) in enumerate(triples): | |||
| if len(sig_rel[(colours[h], colours[t])]) == 1: | |||
| cand.append(i) | |||
| if cand: | |||
| idx = rng.choice(cand) | |||
| h, r, t = triples.pop(idx) | |||
| out_adj[h].remove((r, t)) | |||
| in_adj[t].remove((r, h)) | |||
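# Example (added sketch, illustrative only): exercising the edit primitives on
# a tiny invented graph. With colours [0, 0, 1, 1], `add_div_edge` can only
# connect {0, 1} to {2, 3}, while `rem_det_edge` removes an edge whose colour
# signature carries a single relation:
#
#   rng = random.Random(0)
#   toy = [(0, 0, 2), (1, 1, 3)]
#   out_adj, in_adj = defaultdict(list), defaultdict(list)
#   for h, r, t in toy:
#       out_adj[h].append((r, t))
#       in_adj[t].append((r, h))
#   add_div_edge(toy, out_adj, in_adj, [0, 0, 1, 1], depth=1, n_ent=4, n_rel=2, rng=rng)
#   rem_det_edge(toy, out_adj, in_adj, [0, 0, 1, 1], depth=1, rng=rng)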
| def _worker( | |||
| *, | |||
| worker_id: int, | |||
| rng: random.Random, | |||
| triples: List[Tuple[int, int, int]], | |||
| out_adj: Dict[int, List[Tuple[int, int]]], | |||
| in_adj: Dict[int, List[Tuple[int, int]]], | |||
| colours: List[int], | |||
| n_ent: int, | |||
| n_rel: int, | |||
| depth: int, | |||
| c: float, | |||
| lo: float, | |||
| hi: float, | |||
| max_iters: int, | |||
| state_lock: threading.Lock, | |||
| stop_event: threading.Event, | |||
| ): | |||
| """One thread: mutate the *shared* structures until success/stop.""" | |||
| for it in range(max_iters): | |||
| if stop_event.is_set(): | |||
| return # someone else finished ─ exit early | |||
with state_lock: # protect the shared graph
if lo <= c <= hi:
logging.info(
"[worker %d] converged after %d steps (CREC %.4f, |T|=%d)",
worker_id,
it,
c,
len(triples),
)
stop_event.set()
return
# Choose and apply one edit -----------------------------
# NOTE: `c` is the WL-CREC snapshot taken before the threads started; it is
# never recomputed here, so the check above and the edit direction below
# both work off a stale value.
if c < lo: # need ↑ CREC
| if rng.random() < 0.5: | |||
| rem_det_edge(triples, out_adj, in_adj, colours, depth, rng) | |||
| else: | |||
| add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng) | |||
| elif rng.random() < 0.5: | |||
| rem_div_edge(triples, out_adj, in_adj, colours, depth, rng) | |||
| else: | |||
| add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng) | |||
logging.warning("[worker %d] reached max_iters", worker_id)
# do not set stop_event on failure: fast_tune_crec() treats it as the
# success flag and would otherwise report a converged run
return
| # -------------------------------------------------------------------- | |||
| # Public API -------------------------------------------------------- | |||
| # -------------------------------------------------------------------- | |||
| def fast_tune_crec( | |||
| triples: List, | |||
| colours: List, | |||
| n_ent: int, | |||
| n_rel: int, | |||
| depth: int, | |||
| c: float, | |||
| lo: float, | |||
| hi: float, | |||
| max_iters: int = 1000, | |||
| max_workers: int | None = None, | |||
| seeds: List[int] | None = None, | |||
| ) -> List[Tuple[int, int, int]]: | |||
| """Tune WL‑CREC with *shared* triples using multiple threads. | |||
| Returns the **same list instance** that was passed in – already | |||
| modified in place by the winning thread. | |||
| """ | |||
| if max_workers is None: | |||
| # max_workers = os.cpu_count() or 4 | |||
| max_workers = 4 | |||
| if seeds is None: | |||
| seeds = [42 + i for i in range(max_workers)] | |||
| assert len(seeds) >= max_workers, "Need at least one seed per worker" | |||
| # Prepare adjacency once (shared) -------------------------------- | |||
| out_adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list) | |||
| in_adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list) | |||
| for h, r, t in triples: | |||
| out_adj[h].append((r, t)) | |||
| in_adj[t].append((r, h)) | |||
| state_lock = threading.Lock() | |||
| stop_event = threading.Event() | |||
| logging.info( | |||
| "Launching %d threads on shared triples (target %.3f–%.3f)", | |||
| max_workers, | |||
| lo, | |||
| hi, | |||
| ) | |||
| threads: List[threading.Thread] = [] | |||
| for wid in range(max_workers): | |||
| t = threading.Thread( | |||
| name=f"tune‑crec‑worker‑{wid}", | |||
| target=_worker, | |||
| kwargs=dict( | |||
| worker_id=wid, | |||
| rng=random.Random(seeds[wid]), | |||
| triples=triples, | |||
| out_adj=out_adj, | |||
| in_adj=in_adj, | |||
| colours=colours, | |||
| n_ent=n_ent, | |||
| n_rel=n_rel, | |||
| depth=depth, | |||
| c=c, | |||
| lo=lo, | |||
| hi=hi, | |||
| max_iters=max_iters, | |||
| state_lock=state_lock, | |||
| stop_event=stop_event, | |||
| ), | |||
| daemon=False, | |||
| ) | |||
| threads.append(t) | |||
| t.start() | |||
| for t in threads: | |||
| t.join() | |||
| if not stop_event.is_set(): | |||
| raise RuntimeError("No thread converged – try increasing max_iters or widen band") | |||
| return triples | |||
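# Added sanity sketch (not original code): a minimal, runnable distillation of
# the coordination pattern `_worker` relies on - several threads mutating one
# shared object under a lock until an Event signals completion.
if __name__ == "__main__":
    counter = [0]
    demo_lock = threading.Lock()
    demo_done = threading.Event()

    def _bump() -> None:
        while not demo_done.is_set():
            with demo_lock:
                counter[0] += 1
                if counter[0] >= 10_000:
                    demo_done.set()

    demo_threads = [threading.Thread(target=_bump) for _ in range(4)]
    for th in demo_threads:
        th.start()
    for th in demo_threads:
        th.join()
    logging.info("demo finished: counter=%d", counter[0])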
| @@ -0,0 +1,106 @@ | |||
| from __future__ import annotations | |||
| import logging | |||
| import random | |||
| from collections import defaultdict, deque | |||
| from typing import List | |||
| import torch | |||
| from data.kg_dataset import KGDataset | |||
| from metrics.wlcrec import WLCREC | |||
| from tools import get_pretty_logger | |||
| logging = get_pretty_logger(__name__) | |||
| # -------------------------------------------------------------------- | |||
| # Radius-k patch sampler | |||
| # -------------------------------------------------------------------- | |||
| def radius_patch_sample( | |||
| all_triples: List, | |||
| ratio: float, | |||
| radius: int, | |||
| rng: random.Random, | |||
| ) -> List: | |||
| """ | |||
| Take ~ratio fraction of triples by picking random seeds and | |||
| keeping *all* triples within ≤ radius hops (undirected) of them. | |||
| """ | |||
| n_total = len(all_triples) | |||
| target = int(ratio * n_total) | |||
| # Build entity → triple-indices adjacency for BFS | |||
| ent2trip = defaultdict(list) | |||
| for idx, (h, _, t) in enumerate(all_triples): | |||
| ent2trip[h].append(idx) | |||
| ent2trip[t].append(idx) | |||
| chosen_idx: set[int] = set() | |||
| visited_ent = set() | |||
| while len(chosen_idx) < target: | |||
| seed_idx = rng.randrange(n_total) | |||
| if seed_idx in chosen_idx: | |||
| continue | |||
| # BFS over entities | |||
| queue = deque() | |||
| for e in all_triples[seed_idx][::2]: # subject & object | |||
| if e not in visited_ent: | |||
| queue.append((e, 0)) | |||
| visited_ent.add(e) | |||
| while queue: | |||
| ent, dist = queue.popleft() | |||
| for tidx in ent2trip[ent]: | |||
| if tidx not in chosen_idx: | |||
| chosen_idx.add(tidx) | |||
| if dist < radius: | |||
| for tidx in ent2trip[ent]: | |||
| h, _, t = all_triples[tidx] | |||
| for nb in (h, t): | |||
| if nb not in visited_ent: | |||
| visited_ent.add(nb) | |||
| queue.append((nb, dist + 1)) | |||
# guard against an endless loop when the target exceeds the graph size (ratio >= 1)
if len(chosen_idx) >= n_total:
break
| return [all_triples[i] for i in chosen_idx] | |||
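# Example (added sketch, invented data): sampling roughly half of a six-triple
# ring with radius 1; every kept triple lies within one hop of a seed triple.
#
#   rng = random.Random(0)
#   ring = [(0, 0, 1), (1, 0, 2), (2, 1, 3), (3, 0, 4), (4, 1, 5), (5, 0, 0)]
#   patch = radius_patch_sample(ring, ratio=0.5, radius=1, rng=rng)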
| # -------------------------------------------------------------------- | |||
| # Main driver: repeated sampling until WL-CREC band hit | |||
| # -------------------------------------------------------------------- | |||
| def find_sample( | |||
| all_triples: List, | |||
| n_ent: int, | |||
| n_rel: int, | |||
| depth: int, | |||
| ratio: float, | |||
| radius: int, | |||
| lower: float, | |||
| upper: float, | |||
| max_tries: int, | |||
| seed: int, | |||
| ) -> List: | |||
| for trial in range(1, max_tries + 1): | |||
| rng = random.Random(seed + trial) # fresh randomness each try | |||
| sample = radius_patch_sample(all_triples, ratio, radius, rng) | |||
| sample_ds = KGDataset( | |||
| triples_factory=None, | |||
| triples=torch.tensor(sample, dtype=torch.long), | |||
| num_entities=n_ent, | |||
| num_relations=n_rel, | |||
| split="train", | |||
| ) | |||
| wl_crec = WLCREC(sample_ds) | |||
| crec_val = wl_crec.compute(depth)[-2] | |||
| logging.info("Try %d |T'|=%d WL-CREC=%.4f", trial, len(sample), crec_val) | |||
| if lower <= crec_val <= upper: | |||
| logging.info("Success after %d tries.", trial) | |||
| return sample | |||
| raise RuntimeError(f"No sample reached WL-CREC band in {max_tries} tries.") | |||
| @@ -0,0 +1,226 @@ | |||
| from __future__ import annotations | |||
| import logging | |||
| import random | |||
| from math import log | |||
| from typing import Dict, List, Tuple | |||
| import torch | |||
| from torch import Tensor | |||
| from tools import get_pretty_logger | |||
| logging = get_pretty_logger(__name__) | |||
| Entity = int | |||
| Relation = int | |||
| Triple = Tuple[Entity, Relation, Entity] | |||
| Color = int | |||
| # --------------------------------------------------------------------- | |||
| # Weisfeiler–Lehman colouring (GPU friendly) | |||
| # --------------------------------------------------------------------- | |||
| def wl_colours_gpu(triples: Tensor, n_ent: int, depth: int, device: torch.device) -> List[Tensor]: | |||
| """ | |||
| GPU implementation of classic Weisfeiler‑Lehman refinement. | |||
| 1. Each node starts with colour 0. | |||
| 2. At iteration h we hash the multiset of | |||
| ⬇︎ (r, colour(t)) and ⬆︎ (r, colour(h)) | |||
| signatures into new colour ids. | |||
| 3. Hashing is done with a fast 64‑bit mix; collisions are unlikely and | |||
| immaterial for CREC (only *relative* colours matter). | |||
| """ | |||
| h, r, t = triples.unbind(dim=1) # (|T|,) | |||
| colours = [torch.zeros(n_ent, dtype=torch.long, device=device)] # C⁰ = 0 | |||
| # pre‑compute 64‑bit mix constants once | |||
| mix_r = ( | |||
| torch.arange(triples[:, 1].max() + 1, device=device, dtype=torch.long) | |||
| * 1_146_189_683_093_321_123 | |||
| ) | |||
| mix_c = 8_636_673_225_737_527_201 | |||
| for _ in range(depth): | |||
| prev = colours[-1] # (|V|,) | |||
| # signatures for outgoing edges | |||
| sig_out = mix_r[r] ^ (prev[t] * mix_c) | |||
# signatures for incoming edges (the golden-ratio constant is written in
# signed form: the unsigned literal 0x9E3779B97F4A7C15 overflows torch.long)
sig_in = mix_r[r] ^ (prev[h] * mix_c) ^ -0x61C8864680B583EB
| # bucket signatures back to the source/target nodes | |||
| col_out = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, h, sig_out) | |||
| col_in = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, t, sig_in) | |||
| # combine ↓ and ↑ multiset hashes with current colour | |||
| raw = (prev * 3_205_813_371) ^ col_out ^ (col_in << 1) | |||
| # re‑map to dense consecutive ids with torch.unique | |||
| uniq, new = torch.unique(raw, sorted=True, return_inverse=True) | |||
| colours.append(new) | |||
| return colours # length = depth+1, each (|V|,) long | |||
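# Example (added sketch): two refinement rounds on an invented 3-entity graph.
# Entities 0 and 1 have identical neighbourhoods, so they keep a shared colour,
# while entity 2 (two incoming edges) is separated after round 1:
#
#   dev = torch.device("cpu")
#   toy = torch.tensor([[0, 0, 2], [1, 0, 2]], dtype=torch.long)
#   cols = wl_colours_gpu(toy, n_ent=3, depth=2, device=dev)
#   assert cols[1][0] == cols[1][1] != cols[1][2]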
| # --------------------------------------------------------------------- | |||
| # WL‑CREC metric | |||
| # --------------------------------------------------------------------- | |||
| def wl_crec_gpu( | |||
| triples: Tensor, colours: List[Tensor], n_ent: int, n_rel: int, depth: int, device: torch.device | |||
| ) -> float: | |||
| log_n = log(max(n_ent, 2)) | |||
| # ------- pattern‑diversity C ------------------------------------- | |||
| C = 0.0 | |||
| for c in colours: # (|V|,) | |||
| # histogram via bincount on GPU | |||
| hist = torch.bincount(c).float() | |||
| p = hist / n_ent | |||
| C += -(p * torch.log(p.clamp_min(1e-30))).sum().item() | |||
| C /= (depth + 1) * log_n | |||
| # ------- residual entropy H_c ------------------------------------ | |||
| h, r, t = triples.unbind(dim=1) | |||
| sigma, tau = colours[depth][h], colours[depth][t] # (|T|,) | |||
| # encode (sigma,tau) pairs into a single 64‑bit key | |||
| sig_keys = (sigma.to(torch.int64) << 32) + tau.to(torch.int64) | |||
| # total count per signature | |||
| sig_unique, sig_inv, sig_counts = torch.unique( | |||
| sig_keys, return_inverse=True, return_counts=True, sorted=False | |||
| ) | |||
| # build 2‑D contingency table counts[(sigma,tau), r] | |||
m = triples.size(0)
keys_2d = sig_inv * n_rel + r
# count (signature, relation) pairs with a flat index_add_; scatter_add_ into
# a 2-D table with these raw keys would index out of bounds
rc_dense = torch.zeros(len(sig_unique) * n_rel, device=device, dtype=torch.long)
rc_dense.index_add_(0, keys_2d, torch.ones(m, device=device, dtype=torch.long))
rc_dense = rc_dense.view(len(sig_unique), n_rel)
| # conditional entropy per (sigma,tau) | |||
| p_r = rc_dense.float() / sig_counts.unsqueeze(1).float() | |||
| inner = -(p_r * torch.log(p_r.clamp_min(1e-30))).sum(dim=1) # (|sigma|,) | |||
| Hc = (sig_counts.float() / m * inner).sum().item() / log(max(n_rel, 2)) | |||
| return C * Hc | |||
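# Example (added sketch), continuing the toy graph above: with a single
# relation everywhere, every signature is deterministic, so H_c = 0 and the
# product C * H_c vanishes:
#
#   assert wl_crec_gpu(toy, cols, n_ent=3, n_rel=1, depth=2, device=dev) == 0.0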
| # --------------------------------------------------------------------- | |||
| # delta‑entropy helpers (still CPU side for clarity; cost is negligible) | |||
| # --------------------------------------------------------------------- | |||
| # The deterministic delta‑formulas depend on small dictionaries and are evaluated | |||
| # on ≤256 candidate edges each iteration – copy them from the original script. | |||
| inner_entropy = ... # unchanged | |||
| delta_h_cond_remove = ... # unchanged | |||
| delta_h_cond_add = ... # unchanged | |||
| # --------------------------------------------------------------------- | |||
| # Greedy delta‑search driver | |||
| # --------------------------------------------------------------------- | |||
| def greedy_delta_tune_gpu( | |||
| triples: List[Triple], | |||
| n_ent: int, | |||
| n_rel: int, | |||
| depth: int, | |||
| lower: float, | |||
| upper: float, | |||
| device: torch.device, | |||
| max_iters: int = 40_000, | |||
| sample_size: int = 256, | |||
| seed: int = 0, | |||
| ) -> List[Triple]: | |||
| # book‑keeping in Python lists (cheap); heavy math in torch on GPU | |||
| rng = random.Random(seed) | |||
| # cache torch version of triples for fast metric eval | |||
| def triples_to_tensor(ts: List[Triple]) -> Tensor: | |||
| return torch.tensor(ts, dtype=torch.long, device=device, requires_grad=False) | |||
| for it in range(max_iters): | |||
| tri_tensor = triples_to_tensor(triples) | |||
| colours = wl_colours_gpu(tri_tensor, n_ent, depth, device) | |||
| crec = wl_crec_gpu(tri_tensor, colours, n_ent, n_rel, depth, device) | |||
| if lower <= crec <= upper: | |||
| logging.info("WL‑CREC %.4f reached after %d edits (|T|=%d)", crec, it, len(triples)) | |||
| return triples | |||
| # ---------------------------------------------------------------- | |||
| # build signature statistics (CPU, cheap: O(|T|)) | |||
| # ---------------------------------------------------------------- | |||
| depth_col = colours[depth].cpu() | |||
| sig_cnt: Dict[Tuple[int, int], int] = {} | |||
| rel_cnt: Dict[Tuple[int, int, int], int] = {} | |||
for h, r, t in triples:
| sigma, tau = int(depth_col[h]), int(depth_col[t]) | |||
| sig_cnt[(sigma, tau)] = sig_cnt.get((sigma, tau), 0) + 1 | |||
| rel_cnt[(sigma, tau, r)] = rel_cnt.get((sigma, tau, r), 0) + 1 | |||
| det_edges, div_edges = [], [] | |||
| for idx, (h, r, t) in enumerate(triples): | |||
| sigma, tau = int(depth_col[h]), int(depth_col[t]) | |||
| if sum(rel_cnt.get((sigma, tau, rr), 0) > 0 for rr in range(n_rel)) == 1: | |||
| det_edges.append(idx) | |||
| if sigma != tau: | |||
| div_edges.append(idx) | |||
| # ---------------------------------------------------------------- | |||
| # candidate generation + best edit selection | |||
| # ---------------------------------------------------------------- | |||
| total = len(triples) | |||
| target_high = crec < lower # need to raise WL‑CREC? | |||
| candidates = [] | |||
| if target_high and det_edges: | |||
| rng.shuffle(det_edges) | |||
| for idx in det_edges[:sample_size]: | |||
| h, r, t = triples[idx] | |||
| sig = (int(depth_col[h]), int(depth_col[t])) | |||
| delta = delta_h_cond_remove(sig, r, sig_cnt, rel_cnt, total, crec, n_rel) | |||
| if delta > 0: | |||
| candidates.append(("remove", idx, delta)) | |||
| elif not target_high and det_edges: | |||
| rng.shuffle(det_edges) | |||
| for idx in det_edges[:sample_size]: | |||
| h, r, t = triples[idx] | |||
| sig = (int(depth_col[h]), int(depth_col[t])) | |||
| delta = delta_h_cond_add(sig, r, sig_cnt, rel_cnt, total, crec, n_rel) | |||
| if delta < 0: | |||
| candidates.append(("add", (h, r, t), delta)) | |||
# fall‑back heuristics
if not candidates:
if target_high and div_edges:
idx = rng.choice(div_edges)
candidates.append(("remove", idx, 1e-9))
elif not target_high:
idx = rng.choice(det_edges) if det_edges else rng.randrange(total)
h, r, t = triples[idx]
candidates.append(("add", (h, r, t), -1e-9))
if not candidates:
# nothing applicable this round (raising CREC with no diversifying edges left)
continue
best = (
max(candidates, key=lambda x: x[2])
if target_high
else min(candidates, key=lambda x: x[2])
)
| # apply edit | |||
| if best[0] == "remove": | |||
| triples.pop(best[1]) | |||
| else: | |||
| triples.append(best[1]) | |||
| if (it + 1) % 1_000 == 0: | |||
| logging.info("[iter %d] WL‑CREC %.4f |T|=%d", it + 1, crec, len(triples)) | |||
| raise RuntimeError("Max iterations exceeded without hitting WL‑CREC band.") | |||
| @@ -0,0 +1,30 @@ | |||
| import torch | |||
| from models.base_model import ERModel | |||
def _rank(scores: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# rank = 1 + number of candidates scoring strictly higher; assumes higher
# scores are better (negate distance-style scores before calling)
return (scores > scores[target]).sum() + 1
def simple_ranking(model: ERModel, triples: torch.Tensor) -> torch.Tensor:
h, r, t = triples[:, 0], triples[:, 1], triples[:, 2]
E = model.num_entities
ranks = []
with torch.no_grad():
for hi, ri, ti in zip(h, r, t):
tails = torch.arange(E, device=triples.device)
cand = torch.stack([hi.repeat(E), ri.repeat(E), tails], dim=1)  # don't shadow `triples`
scores = model(cand, mode=None)  # `mode` is keyword-only on ERModel.forward
ranks.append(_rank(scores, ti))
return torch.tensor(ranks, device=triples.device, dtype=torch.float32)
| def mrr(r: torch.Tensor) -> torch.Tensor: | |||
| return (1 / r).mean() | |||
| def hits_at_k(r: torch.Tensor, k: int = 10) -> torch.Tensor: | |||
| return (r <= k).float().mean() | |||
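if __name__ == "__main__":
    # Added self-check with fabricated ranks (no model required); the expected
    # values follow directly from the definitions above.
    ranks = torch.tensor([1.0, 2.0, 5.0, 20.0])
    print(f"MRR     : {mrr(ranks):.4f}")            # (1 + 1/2 + 1/5 + 1/20) / 4 = 0.4375
    print(f"Hits@10 : {hits_at_k(ranks, 10):.4f}")  # 3 of 4 ranks <= 10 -> 0.75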
| @@ -0,0 +1,237 @@ | |||
| from __future__ import annotations | |||
| import math | |||
| from collections import Counter, defaultdict | |||
| from typing import TYPE_CHECKING, Dict, List, Tuple | |||
| import torch | |||
| from .base_metric import BaseMetric | |||
| if TYPE_CHECKING: | |||
| from data.kg_dataset import KGDataset | |||
| def _entropy_from_counter(counter: Counter[int], n: int) -> float: | |||
| """Return Shannon entropy (nats) of a colour distribution of length *n*.""" | |||
| ent = 0.0 | |||
| for cnt in counter.values(): | |||
| if cnt: | |||
| p = cnt / n | |||
| ent -= p * math.log(p) | |||
| return ent | |||
| def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]: | |||
| """Return (C_ratio, C_NWLEC) ∈ [0,1]² from per-layer entropies.""" | |||
| if n <= 1: | |||
| return 0.0, 0.0 | |||
| H_plus_1 = len(entropies) | |||
| log_n = math.log(n) | |||
| c_ratio = sum(ent / log_n for ent in entropies) / H_plus_1 | |||
| c_nwlec = sum((math.exp(ent) - 1) / (n - 1) for ent in entropies) / H_plus_1 | |||
| return c_ratio, c_nwlec | |||
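# Worked example (added, illustrative): for n = 4 entities and per-layer
# entropies [0, log 4] (all merged, then a uniform 4-way split):
#
#   C_ratio = (0/log 4 + log 4/log 4) / 2 = 0.5
#   C_NWLEC = ((e^0 - 1)/3 + (e^{log 4} - 1)/3) / 2 = (0 + 1) / 2 = 0.5
#
#   >>> _normalise_scores([0.0, math.log(4)], 4)
#   (0.5, 0.5)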
| def _relation_families( | |||
| triples: torch.Tensor, num_relations: int, *, thresh: float = 0.9 | |||
| ) -> torch.Tensor: | |||
| """Return a LongTensor mapping each relation id → family id. | |||
| Two relations *r, r_inv* are put into the same family when **at least | |||
| `thresh` fraction** of the edges labelled *r* have a **single** reverse edge | |||
| labelled *r_inv*. | |||
| """ | |||
| assert 0.0 < thresh <= 1.0 | |||
| # Build map (h,t) -> list of relations on that edge direction. | |||
| edge2rels: Dict[Tuple[int, int], List[int]] = defaultdict(list) | |||
| for h, r, t in triples.tolist(): | |||
| edge2rels[(h, t)].append(int(r)) | |||
| # Count reciprocal co‑occurrences (r, r_rev). | |||
| pair_counts: Dict[Tuple[int, int], int] = defaultdict(int) | |||
| rel_totals: List[int] = [0] * num_relations | |||
| for (h, t), rels in edge2rels.items(): | |||
| rev_rels = edge2rels.get((t, h)) | |||
| if not rev_rels: | |||
| continue | |||
| for r in rels: | |||
| rel_totals[r] += 1 | |||
| for r_rev in rev_rels: | |||
| pair_counts[(r, r_rev)] += 1 | |||
| # Proposed inverse mapping r -> r_inv. | |||
| proposed: List[int | None] = [None] * num_relations | |||
| for (r, r_rev), cnt in pair_counts.items(): | |||
| if cnt == 0: | |||
| continue | |||
| if cnt / rel_totals[r] >= thresh: | |||
| # majority of r's edges are reversed by r_rev | |||
| proposed[r] = r_rev | |||
| # Build family ids (transitive closure is overkill; data are simple). | |||
| family = list(range(num_relations)) | |||
| for r, rinv in enumerate(proposed): | |||
| if rinv is not None: | |||
| root = min(family[r], family[rinv]) | |||
| family[r] = family[rinv] = root | |||
| return torch.tensor(family, dtype=torch.long) | |||
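# Example (added sketch, invented triples): relation 1 reverses every edge of
# relation 0, so the two are merged into a single family:
#
#   toy = torch.tensor([[0, 0, 1], [1, 1, 0], [2, 0, 3], [3, 1, 2]])
#   _relation_families(toy, num_relations=2)  # -> tensor([0, 0])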
| def _conditional_relation_entropy( | |||
| colours: List[int], | |||
| triples: torch.Tensor, | |||
| family_map: torch.Tensor, | |||
| ) -> float: | |||
| counts: Dict[Tuple[int, int], Dict[int, int]] = defaultdict(lambda: defaultdict(int)) | |||
| m = int(triples.size(0)) | |||
| fam = family_map # alias for speed | |||
| for h, r, t in triples.tolist(): | |||
| # key = (colours[h], colours[t]) # ordered key (direction‑aware) | |||
| key = tuple(sorted((colours[h], colours[t]))) # unordered key (direction‑agnostic) | |||
| counts[key][int(fam[r])] += 1 # ▼ use family id | |||
| h_cond = 0.0 | |||
| for rel_counts in counts.values(): | |||
| total_s = sum(rel_counts.values()) | |||
| P_s = total_s / m | |||
| inv_total = 1.0 / total_s | |||
| for cnt in rel_counts.values(): | |||
| p_rs = cnt * inv_total | |||
| h_cond -= P_s * p_rs * math.log(p_rs) | |||
| return h_cond | |||
| class WLCREC(BaseMetric): | |||
| """Class to compute the Weisfeiler-Lehman Entropy Complexity (WLEC) for a knowledge graph.""" | |||
| def __init__(self, dataset: KGDataset) -> None: | |||
| super().__init__(dataset) | |||
| def compute( | |||
| self, | |||
| H: int = 3, | |||
| cond_h: int = 1, | |||
| inv_thresh: float = 0.9, | |||
| return_full: bool = False, | |||
| progress: bool = True, | |||
) -> Tuple[float, ...]:
"""Compute WL-entropy scores and the composite difficulty metric.
Parameters
----------
H
WL refinement depth (*≥0*). Entropy is recorded for **H+1** layers.
cond_h
WL depth whose colouring conditions the residual relation entropy.
inv_thresh
Fraction of reciprocal edges required to merge a relation with its inverse.
return_full
Also return the list of per-layer entropies.
progress
Print textual progress bar.
Returns
-------
avg_entropy, C_ratio, C_NWLEC, H_cond, D_ratio, D_NWLEC
*Six* scalars (floats). If ``return_full=True`` the list of layer
entropies is appended as a seventh element.
| """ | |||
| # compute relation family map once | |||
| family_map = _relation_families( | |||
| self.dataset.triples, self.dataset.num_relations, thresh=inv_thresh | |||
| ) | |||
| # ─── WL iterations ──────────────────────────────────────────────────── | |||
| n = self.dataset.num_entities | |||
| colour: List[int] = [1] * n # colour 1 for everyone | |||
| entropies: List[float] = [] | |||
| colours_per_h: List[List[int]] = [] | |||
| def _print_prog(i: int) -> None: | |||
| if progress: | |||
| print(f"\rWL iteration {i}/{H}", end="", flush=True) | |||
| for h in range(H + 1): | |||
| _print_prog(h) | |||
| colours_per_h.append(colour.copy()) | |||
| # entropy of current colouring | |||
| freq = Counter(colour) | |||
| entropies.append(_entropy_from_counter(freq, n)) | |||
| if h == H: | |||
| break | |||
| # refine colours | |||
| bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {} | |||
| next_colour: List[int] = [0] * n | |||
| for v in range(n): | |||
| T: List[Tuple[int, int, int]] = [] | |||
| for r, t in self.out_adj[v]: | |||
| T.append((r, 0, colour[t])) # outgoing | |||
| for r, h_ in self.in_adj[v]: | |||
| T.append((r, 1, colour[h_])) # incoming | |||
| T.sort() | |||
| key = (colour[v], tuple(T)) | |||
| if key not in bucket: | |||
| bucket[key] = len(bucket) + 1 | |||
| next_colour[v] = bucket[key] | |||
| colour = next_colour | |||
| _print_prog(H) | |||
| if progress: | |||
| print() | |||
| # ─── Normalised diversity ──────────────────────────────────────────── | |||
| avg_entropy = sum(entropies) / len(entropies) | |||
| c_ratio, c_nwlec = _normalise_scores(entropies, n) | |||
| # ─── Residual relation entropy ─────────────────────────────────────── | |||
| h_cond = _conditional_relation_entropy( | |||
| colours_per_h[cond_h], self.dataset.triples, family_map | |||
| ) | |||
| # ─── Composite difficulty - Eq. (1) ────────────────────────────────── | |||
| d_ratio = c_ratio * h_cond | |||
| d_nwlec = c_nwlec * h_cond | |||
| result = (avg_entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec) | |||
| if return_full: | |||
| result += (entropies,) | |||
| return result | |||
| # -------------------------------------------------------------------- | |||
| # Weisfeiler–Lehman colours (unchanged) | |||
| # -------------------------------------------------------------------- | |||
| def wl_colours(self, depth: int) -> List[List[int]]: | |||
| triples = self.dataset.triples.tolist() | |||
| out_adj, in_adj = defaultdict(list), defaultdict(list) | |||
| for h, r, t in triples: | |||
| out_adj[h].append((r, t)) | |||
| in_adj[t].append((r, h)) | |||
| n_ent = self.dataset.num_entities | |||
| colours = [[0] * n_ent] # round-0 | |||
| for h in range(1, depth + 1): | |||
| prev, nxt, sig2c = colours[-1], [0] * n_ent, {} | |||
| fresh = 0 | |||
| for v in range(n_ent): | |||
| neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [ | |||
| ("↑", r, prev[u]) for r, u in in_adj.get(v, []) | |||
| ] | |||
| neigh.sort() | |||
| sig = (prev[v], tuple(neigh)) | |||
| if sig not in sig2c: | |||
| sig2c[sig] = fresh | |||
| fresh += 1 | |||
| nxt[v] = sig2c[sig] | |||
| colours.append(nxt) | |||
| return colours | |||
| @@ -0,0 +1,135 @@ | |||
| from __future__ import annotations | |||
| import math | |||
| from collections import Counter | |||
| from typing import TYPE_CHECKING, Dict, List, Tuple | |||
| from .base_metric import BaseMetric | |||
| if TYPE_CHECKING: | |||
| from data.kg_dataset import KGDataset | |||
| def _entropy_from_counter(counter: Counter[int], n: int) -> float: | |||
| """Shannon entropy ``H = -Σ pᵢ log pᵢ`` in *nats* for the given colour counts.""" | |||
| ent = 0.0 | |||
| for count in counter.values(): | |||
| if count: # avoid log(0) | |||
| p = count / n | |||
| ent -= p * math.log(p) | |||
| return ent | |||
| def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]: | |||
| """Return (C_ratio, C_NWLEC) given *per-iteration* entropies.""" | |||
| if n <= 1: | |||
| # Degenerate graph. | |||
| return 0.0, 0.0 | |||
| H_plus_1 = len(entropies) | |||
| # Eq. (1): simple ratio normalisation. | |||
| c_ratio = sum(ent / math.log(n) for ent in entropies) / H_plus_1 | |||
| # Eq. (2): effective-colour NWLEC. | |||
| k_terms = [(math.exp(ent) - 1) / (n - 1) for ent in entropies] | |||
| c_nwlec = sum(k_terms) / H_plus_1 | |||
| return c_ratio, c_nwlec | |||
| class WLEC(BaseMetric): | |||
| """Class to compute the Weisfeiler-Lehman Entropy Complexity (WLEC) for a knowledge graph.""" | |||
| def __init__(self, dataset: KGDataset) -> None: | |||
| super().__init__(dataset) | |||
| def compute( | |||
| self, | |||
| H: int = 3, | |||
| *, | |||
| return_full: bool = False, | |||
| progress: bool = True, | |||
) -> Tuple[float, ...]:
"""Compute the *average* Weisfeiler-Lehman Entropy Complexity (WLEC).
Parameters
----------
H:
Number of refinement iterations **after** the initial colouring (so the
colours are updated **H** times and the entropy is measured **H+1** times).
return_full:
If *True*, additionally return the list of entropies for each depth.
progress:
Whether to print a mini progress bar (no external dependency).
Returns
-------
(average_entropy, c_ratio, c_nwlec), or
(average_entropy, entropies, c_ratio, c_nwlec) when ``return_full=True``.
| """ | |||
| n = int(self.num_entities) | |||
| if n == 0: | |||
| msg = "Dataset appears to contain zero entities." | |||
| raise ValueError(msg) | |||
| # Colour assignment - we keep it as a *list* of ints for cheap hashing. | |||
| # Start with colour 1 for every entity (index 0 unused). | |||
| colour: List[int] = [1] * n | |||
| entropies: List[float] = [] | |||
| # Optional poor-man's progress bar. | |||
| def _print_prog(i: int) -> None: | |||
| if progress: | |||
| print(f"\rIteration {i}/{H}", end="", flush=True) | |||
| for h in range(H + 1): | |||
| _print_prog(h) | |||
| # ------------------------------------------------------------------ | |||
| # Step 1: measure entropy of current colouring ---------------------- | |||
| # ------------------------------------------------------------------ | |||
| freq = Counter(colour) | |||
| ent = _entropy_from_counter(freq, n) | |||
| entropies.append(ent) | |||
| if h == H: | |||
| break # We have reached the requested depth - stop here. | |||
| # ------------------------------------------------------------------ | |||
| # Step 2: create refined colours ----------------------------------- | |||
| # ------------------------------------------------------------------ | |||
| bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {} | |||
| next_colour: List[int] = [0] * n | |||
| for v in range(n): | |||
| # Build the multiset T containing outgoing and incoming features. | |||
| T: List[Tuple[int, int, int]] = [] | |||
| for r, t in self.out_adj[v]: | |||
T.append((r, 0, colour[t])) # 0 = outgoing (↓)
for r, h_ in self.in_adj[v]:
T.append((r, 1, colour[h_])) # 1 = incoming (↑)
| T.sort() # canonical ordering | |||
| key = (colour[v], tuple(T)) | |||
| if key not in bucket: | |||
| bucket[key] = len(bucket) + 1 # new colour ID (start at 1) | |||
| next_colour[v] = bucket[key] | |||
| colour = next_colour # move to next iteration | |||
| _print_prog(H) | |||
| if progress: | |||
| print() # newline after progress bar | |||
| average_entropy = sum(entropies) / len(entropies) | |||
| c_ratio, c_nwlec = _normalise_scores(entropies, n) | |||
| if return_full: | |||
| return average_entropy, entropies, c_ratio, c_nwlec | |||
| return average_entropy, c_ratio, c_nwlec | |||
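# Usage sketch (added; `ds` is a stand-in for any KGDataset instance):
#
#   wlec = WLEC(ds)
#   avg_ent, c_ratio, c_nwlec = wlec.compute(H=3, progress=False)
#   avg_ent, per_depth, c_ratio, c_nwlec = wlec.compute(H=3, return_full=True, progress=False)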
| @@ -0,0 +1,127 @@ | |||
| from __future__ import annotations | |||
| from abc import ABC, abstractmethod | |||
| from typing import TYPE_CHECKING | |||
| import torch | |||
| from torch import nn | |||
| from torch.nn.functional import log_softmax, softmax | |||
| if TYPE_CHECKING: | |||
| from pykeen.typing import FloatTensor, InductiveMode, LongTensor, MappedTriples, Target | |||
| class ERModel(nn.Module, ABC): | |||
| """Base class for knowledge graph models.""" | |||
| def __init__(self, num_entities: int, num_relations: int, dim: int) -> None: | |||
| super().__init__() | |||
| self.num_entities = num_entities | |||
| self.num_relations = num_relations | |||
| self.dim = dim | |||
| self.inner_model = None | |||
| @property | |||
| def device(self) -> torch.device: | |||
| """Return the device of the model.""" | |||
| return next(self.inner_model.parameters()).device | |||
| @abstractmethod | |||
| def reset_parameters(self, *args, **kwargs) -> None: | |||
| """Reset the parameters of the model.""" | |||
| def predict( | |||
| self, | |||
| hrt_batch: MappedTriples, | |||
| target: Target, | |||
| full_batch: bool = True, | |||
| ids: LongTensor | None = None, | |||
| **kwargs, | |||
| ) -> FloatTensor: | |||
| if not self.inner_model: | |||
| msg = ( | |||
| "Inner model is not set. Please initialize the inner model before calling predict." | |||
| ) | |||
| raise ValueError( | |||
| msg, | |||
| ) | |||
| return self.inner_model.predict( | |||
| hrt_batch=hrt_batch, | |||
| target=target, | |||
| full_batch=full_batch, | |||
| ids=ids, | |||
| **kwargs, | |||
| ) | |||
| def forward( | |||
| self, | |||
| triples: MappedTriples, | |||
| slice_size: int | None = None, | |||
| slice_dim: int = 0, | |||
| *, | |||
| mode: InductiveMode | None, | |||
| ) -> FloatTensor: | |||
| h_indices = triples[:, 0] | |||
| r_indices = triples[:, 1] | |||
| t_indices = triples[:, 2] | |||
| if not self.inner_model: | |||
| msg = ( | |||
| "Inner model is not set. Please initialize the inner model before calling forward." | |||
| ) | |||
| raise ValueError( | |||
| msg, | |||
| ) | |||
| return self.inner_model.forward( | |||
| h_indices=h_indices, | |||
| r_indices=r_indices, | |||
| t_indices=t_indices, | |||
| slice_size=slice_size, | |||
| slice_dim=slice_dim, | |||
| mode=mode, | |||
| ) | |||
| @torch.inference_mode() | |||
| def tail_distribution( # (0) TAIL-given-(head, relation) → pθ(t | h,r) | |||
| self, | |||
| hr_batch: LongTensor, # (B, 2) : (h,r) | |||
| *, | |||
| slice_size: int | None = None, | |||
| mode: InductiveMode | None = None, | |||
| log: bool = False, | |||
| ) -> FloatTensor: # (B, |E|) | |||
| scores = self.inner_model.score_t(hr_batch=hr_batch, slice_size=slice_size, mode=mode) | |||
| return log_softmax(scores, -1) if log else softmax(scores, -1) | |||
| # ----------------------------------------------------------------------- # | |||
| # (1) HEAD-given-(relation, tail) → pθ(h | r,t) # | |||
| # ----------------------------------------------------------------------- # | |||
| @torch.inference_mode() | |||
| def head_distribution( | |||
| self, | |||
| rt_batch: LongTensor, # (B, 2) : (r,t) | |||
| *, | |||
| slice_size: int | None = None, | |||
| mode: InductiveMode | None = None, | |||
| log: bool = False, | |||
| ) -> FloatTensor: # (B, |E|) | |||
| scores = self.inner_model.score_h(rt_batch=rt_batch, slice_size=slice_size, mode=mode) | |||
| return log_softmax(scores, -1) if log else softmax(scores, -1) | |||
| # ----------------------------------------------------------------------- # | |||
| # (2) RELATION-given-(head, tail) → pθ(r | h,t) # | |||
| # ----------------------------------------------------------------------- # | |||
| @torch.inference_mode() | |||
| def relation_distribution( | |||
| self, | |||
| ht_batch: LongTensor, # (B, 2) : (h,t) | |||
| *, | |||
| slice_size: int | None = None, | |||
| mode: InductiveMode | None = None, | |||
| log: bool = False, | |||
| ) -> FloatTensor: # (B, |R|) | |||
| scores = self.inner_model.score_r(ht_batch=ht_batch, slice_size=slice_size, mode=mode) | |||
| return log_softmax(scores, -1) if log else softmax(scores, -1) | |||
| @@ -0,0 +1,70 @@ | |||
| from __future__ import annotations | |||
| from typing import TYPE_CHECKING | |||
| import torch | |||
| from pykeen.models import TransE as PyKEENTransE | |||
| from torch import nn | |||
| from models.base_model import ERModel | |||
| if TYPE_CHECKING: | |||
| from pykeen.typing import MappedTriples | |||
| class TransE(ERModel): | |||
| """ | |||
| TransE model for knowledge graph embedding. | |||
score(h, r, t) = -||h + r - t|| (PyKEEN convention: higher is better)
| """ | |||
| def __init__( | |||
| self, | |||
| num_entities: int, | |||
| num_relations: int, | |||
| triples_factory: MappedTriples, | |||
| dim: int = 200, | |||
| sf_norm: int = 1, | |||
| p_norm: bool = True, | |||
| margin: float | None = None, | |||
| epsilon: float | None = None, | |||
| device: torch.device = torch.device("cpu"), | |||
| ) -> None: | |||
| super().__init__(num_entities, num_relations, dim) | |||
| self.dim = dim | |||
| self.sf_norm = sf_norm | |||
| self.p_norm = p_norm | |||
| self.margin = margin | |||
| self.epsilon = epsilon | |||
| self.inner_model = PyKEENTransE( | |||
| triples_factory=triples_factory, | |||
| embedding_dim=dim, | |||
| scoring_fct_norm=sf_norm, | |||
| power_norm=p_norm, | |||
| random_seed=42, | |||
| ).to(device) | |||
def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None:
# Parameter initialization.
# NOTE: `self.ent_embeddings` / `self.rel_embeddings` are assumed to alias the
# inner PyKEEN model's embedding tables; they are not created in __init__ above.
| if margin is None or epsilon is None: | |||
| # If no margin/epsilon are provided, use Xavier uniform initialization | |||
| nn.init.xavier_uniform_(self.ent_embeddings.weight.data) | |||
| nn.init.xavier_uniform_(self.rel_embeddings.weight.data) | |||
| else: | |||
| # Otherwise, initialize uniformly in [ -(margin+epsilon)/dim , +(margin+epsilon)/dim ] | |||
| self.embedding_range = nn.Parameter( | |||
| torch.Tensor([(margin + epsilon) / self.dim]), | |||
| requires_grad=False, | |||
| ) | |||
| nn.init.uniform_( | |||
| tensor=self.ent_embeddings.weight.data, | |||
| a=-self.embedding_range.item(), | |||
| b=self.embedding_range.item(), | |||
| ) | |||
| nn.init.uniform_( | |||
| tensor=self.rel_embeddings.weight.data, | |||
| a=-self.embedding_range.item(), | |||
| b=self.embedding_range.item(), | |||
| ) | |||
| @@ -0,0 +1,76 @@ | |||
| from __future__ import annotations | |||
| from typing import TYPE_CHECKING | |||
| import torch | |||
| from pykeen.losses import MarginRankingLoss | |||
| from pykeen.models import TransH as PyKEENTransH | |||
| from pykeen.regularizers import LpRegularizer | |||
| from torch import nn | |||
| from models.base_model import ERModel | |||
| if TYPE_CHECKING: | |||
| from pykeen.typing import MappedTriples | |||
| class TransH(ERModel): | |||
| """ | |||
| TransH model for knowledge graph embedding. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| num_entities: int, | |||
| num_relations: int, | |||
| triples_factory: MappedTriples, | |||
| dim: int = 200, | |||
| p_norm: bool = True, | |||
| margin: float | None = None, | |||
| epsilon: float | None = None, | |||
| device: torch.device = torch.device("cpu"), | |||
| **kwargs, | |||
| ) -> None: | |||
| super().__init__(num_entities, num_relations, dim) | |||
| self.dim = dim | |||
| self.p_norm = p_norm | |||
| self.margin = margin | |||
| self.epsilon = epsilon | |||
| self.inner_model = PyKEENTransH( | |||
| triples_factory=triples_factory, | |||
| embedding_dim=dim, | |||
| scoring_fct_norm=1, | |||
| loss=self.loss, | |||
| power_norm=p_norm, | |||
| entity_initializer="xavier_uniform_", # good default | |||
| relation_initializer="xavier_uniform_", | |||
| random_seed=42, | |||
| ).to(device) | |||
@property
def loss(self) -> MarginRankingLoss:
margin = self.margin if self.margin is not None else 1.0 # PyKEEN's default margin
return MarginRankingLoss(margin=margin, reduction="mean")
def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None:
# Parameter initialization (same caveat as TransE: `ent_embeddings` /
# `rel_embeddings` are assumed aliases of the inner model's embedding tables)
| if margin is None or epsilon is None: | |||
| # If no margin/epsilon are provided, use Xavier uniform initialization | |||
| nn.init.xavier_uniform_(self.ent_embeddings.weight.data) | |||
| nn.init.xavier_uniform_(self.rel_embeddings.weight.data) | |||
| else: | |||
| # Otherwise, initialize uniformly in [ -(margin+epsilon)/dim , +(margin+epsilon)/dim ] | |||
| self.embedding_range = nn.Parameter( | |||
| torch.Tensor([(margin + epsilon) / self.dim]), | |||
| requires_grad=False, | |||
| ) | |||
| nn.init.uniform_( | |||
| tensor=self.ent_embeddings.weight.data, | |||
| a=-self.embedding_range.item(), | |||
| b=self.embedding_range.item(), | |||
| ) | |||
| nn.init.uniform_( | |||
| tensor=self.rel_embeddings.weight.data, | |||
| a=-self.embedding_range.item(), | |||
| b=self.embedding_range.item(), | |||
| ) | |||
| @@ -0,0 +1,79 @@ | |||
| from __future__ import annotations | |||
| from typing import TYPE_CHECKING | |||
| import torch | |||
| from pykeen.losses import MarginRankingLoss | |||
| from pykeen.models import TransR as PyKEENTransR | |||
| from torch import nn | |||
| from models.base_model import ERModel | |||
| if TYPE_CHECKING: | |||
| from pykeen.typing import MappedTriples | |||
| class TransR(ERModel): | |||
| """ | |||
| TransR model for knowledge graph embedding. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| num_entities: int, | |||
| num_relations: int, | |||
| triples_factory: MappedTriples, | |||
| entity_dim: int = 200, | |||
| relation_dim: int = 30, | |||
| p_norm: bool = True, | |||
| p_norm_value: int = 2, | |||
| margin: float | None = None, | |||
| epsilon: float | None = None, | |||
| device: torch.device = torch.device("cpu"), | |||
| **kwargs, | |||
| ) -> None: | |||
| super().__init__(num_entities, num_relations, entity_dim) | |||
| self.entity_dim = entity_dim | |||
| self.relation_dim = relation_dim | |||
| self.p_norm = p_norm | |||
| self.margin = margin | |||
| self.epsilon = epsilon | |||
| self.inner_model = PyKEENTransR( | |||
| triples_factory=triples_factory, | |||
| embedding_dim=entity_dim, | |||
| relation_dim=relation_dim, | |||
| scoring_fct_norm=1, | |||
| loss=self.loss, | |||
| power_norm=p_norm, | |||
| entity_initializer="xavier_uniform_", | |||
| relation_initializer="xavier_uniform_", | |||
| random_seed=42, | |||
| ).to(device) | |||
@property
def loss(self) -> MarginRankingLoss:
margin = self.margin if self.margin is not None else 1.0 # PyKEEN's default margin
return MarginRankingLoss(margin=margin, reduction="mean")
def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None:
# Parameter initialization (same caveat as TransE: `ent_embeddings` /
# `rel_embeddings` are assumed aliases of the inner model's embedding tables)
| if margin is None or epsilon is None: | |||
| # If no margin/epsilon are provided, use Xavier uniform initialization | |||
| nn.init.xavier_uniform_(self.ent_embeddings.weight.data) | |||
| nn.init.xavier_uniform_(self.rel_embeddings.weight.data) | |||
| else: | |||
| # Otherwise, initialize uniformly in [ -(margin+epsilon)/dim , +(margin+epsilon)/dim ] | |||
| self.embedding_range = nn.Parameter( | |||
| torch.Tensor([(margin + epsilon) / self.dim]), | |||
| requires_grad=False, | |||
| ) | |||
| nn.init.uniform_( | |||
| tensor=self.ent_embeddings.weight.data, | |||
| a=-self.embedding_range.item(), | |||
| b=self.embedding_range.item(), | |||
| ) | |||
| nn.init.uniform_( | |||
| tensor=self.rel_embeddings.weight.data, | |||
| a=-self.embedding_range.item(), | |||
| b=self.embedding_range.item(), | |||
| ) | |||
| @@ -0,0 +1,82 @@ | |||
| [tool.ruff] | |||
| # Exclude a variety of commonly ignored directories. | |||
| exclude = [ | |||
| ".bzr", | |||
| ".direnv", | |||
| ".eggs", | |||
| ".git", | |||
| ".git-rewrite", | |||
| ".hg", | |||
| ".ipynb_checkpoints", | |||
| ".mypy_cache", | |||
| ".nox", | |||
| ".pants.d", | |||
| ".pyenv", | |||
| ".pytest_cache", | |||
| ".pytype", | |||
| ".ruff_cache", | |||
| ".svn", | |||
| ".tox", | |||
| ".venv", | |||
| ".vscode", | |||
| "__pypackages__", | |||
| "_build", | |||
| "buck-out", | |||
| "build", | |||
| "dist", | |||
| "node_modules", | |||
| "site-packages", | |||
| "venv", | |||
| ] | |||
| respect-gitignore = true | |||
# Black's default line length is 88; this project allows 100.
| line-length = 100 | |||
| indent-width = 4 | |||
| [tool.ruff.lint] | |||
| # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. | |||
| # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or | |||
| # McCabe complexity (`C901`) by default. | |||
| select = ["ALL"] | |||
| ignore = ["ANN002", "ANN003", "RET504", "ERA001", "E741", "N812", "D", "PLR0913", "T201", "FBT001", "FBT002", "PLW0602", "PLW0603", "PTH", "C901", "PLR2004", "UP006", "UP035", "G004", "N803", "N806", "ARG002"] | |||
# Allow fix for all enabled rules (when `--fix` is provided).
fixable = ["ALL"]
unfixable = []
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E", "F", "I", "N"]
| # Allow unused variables when underscore-prefixed. | |||
| #dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" | |||
| [tool.ruff.format] | |||
| # Like Black, use double quotes for strings. | |||
| quote-style = "double" | |||
| # Like Black, indent with spaces, rather than tabs. | |||
| indent-style = "space" | |||
| # Like Black, respect magic trailing commas. | |||
| skip-magic-trailing-comma = false | |||
| # Like Black, automatically detect the appropriate line ending. | |||
| line-ending = "auto" | |||
| # Enable auto-formatting of code examples in docstrings. Markdown, | |||
| # reStructuredText code/literal blocks and doctests are all supported. | |||
| # | |||
| # This is currently disabled by default, but it is planned for this | |||
| # to be opt-out in the future. | |||
| docstring-code-format = false | |||
| # Set the line length limit used when formatting code snippets in | |||
| # docstrings. | |||
| # | |||
| # This only has an effect when the `docstring-code-format` setting is | |||
| # enabled. | |||
| docstring-code-line-length = "dynamic" | |||
| [tool.pyright] | |||
| typeCheckingMode = "off" | |||
| @@ -0,0 +1,28 @@ | |||
| #!/usr/bin/env bash | |||
| # install.sh – minimal Conda bootstrap with ~/.bashrc check | |||
| set -e | |||
ENV=kg-env # environment name; the -n flag below overrides the `name:` field in the YAML
YAML=kg_env.yaml # assumed to be in the current dir
| # ── 1. try to load an existing conda from ~/.bashrc ────────────────────── | |||
| [ -f "$HOME/.bashrc" ] && source "$HOME/.bashrc" || true | |||
| # ── 2. ensure conda exists ─────────────────────────────────────────────── | |||
| if ! command -v conda &>/dev/null; then | |||
| echo "[install.sh] Conda not found → installing Miniconda." | |||
| curl -fsSL https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o miniconda.sh | |||
| bash miniconda.sh -b -p "$HOME/miniconda3" | |||
| rm miniconda.sh | |||
| source "$HOME/miniconda3/etc/profile.d/conda.sh" | |||
| conda init bash >/dev/null # so future shells have it | |||
| else | |||
| # conda exists; load its helper for this shell | |||
| source "$(conda info --base)/etc/profile.d/conda.sh" | |||
| fi | |||
| # ── 3. create or update the project env ────────────────────────────────── | |||
| conda env create -n "$ENV" -f "$YAML" \ | |||
| || conda env update -n "$ENV" -f "$YAML" --prune | |||
| echo "✔ Done. Activate with: conda activate $ENV" | |||
| @@ -0,0 +1,163 @@ | |||
# kg_env.yaml ── KG-Embeddings project (referenced by install.sh)
| name: kg | |||
| channels: | |||
| - pytorch | |||
| - defaults | |||
| - nvidia | |||
| - conda-forge | |||
| - https://repo.anaconda.com/pkgs/main | |||
| - https://repo.anaconda.com/pkgs/r | |||
| dependencies: | |||
| - _libgcc_mutex=0.1=conda_forge | |||
| - _openmp_mutex=4.5=2_gnu | |||
| - aom=3.6.1=h59595ed_0 | |||
| - blas=1.0=mkl | |||
| - brotli-python=1.1.0=py311hfdbb021_2 | |||
| - bzip2=1.0.8=h4bc722e_7 | |||
| - ca-certificates=2024.12.14=hbcca054_0 | |||
| - certifi=2024.12.14=pyhd8ed1ab_0 | |||
| - cffi=1.17.1=py311hf29c0ef_0 | |||
| - charset-normalizer=3.4.1=pyhd8ed1ab_0 | |||
| - cpython=3.11.11=py311hd8ed1ab_1 | |||
| - cuda-cudart=12.4.127=0 | |||
| - cuda-cupti=12.4.127=0 | |||
| - cuda-libraries=12.4.1=0 | |||
| - cuda-nvrtc=12.4.127=0 | |||
| - cuda-nvtx=12.4.127=0 | |||
| - cuda-opencl=12.6.77=0 | |||
| - cuda-runtime=12.4.1=0 | |||
| - cuda-version=12.6=3 | |||
| - ffmpeg=4.4.2=gpl_hdf48244_113 | |||
| - filelock=3.16.1=pyhd8ed1ab_1 | |||
| - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 | |||
| - font-ttf-inconsolata=3.000=h77eed37_0 | |||
| - font-ttf-source-code-pro=2.038=h77eed37_0 | |||
| - font-ttf-ubuntu=0.83=h77eed37_3 | |||
| - fontconfig=2.15.0=h7e30c49_1 | |||
| - fonts-conda-ecosystem=1=0 | |||
| - fonts-conda-forge=1=0 | |||
| - freetype=2.12.1=h267a509_2 | |||
| - gettext=0.22.5=he02047a_3 | |||
| - gettext-tools=0.22.5=he02047a_3 | |||
| - giflib=5.2.2=hd590300_0 | |||
| - gmp=6.3.0=hac33072_2 | |||
| - gmpy2=2.1.5=py311h0f6cedb_3 | |||
| - gnutls=3.7.9=hb077bed_0 | |||
| - h2=4.1.0=pyhd8ed1ab_1 | |||
| - hpack=4.0.0=pyhd8ed1ab_1 | |||
| - hyperframe=6.0.1=pyhd8ed1ab_1 | |||
| - idna=3.10=pyhd8ed1ab_1 | |||
| - intel-openmp=2022.0.1=h06a4308_3633 | |||
| - jinja2=3.1.5=pyhd8ed1ab_0 | |||
| - lame=3.100=h166bdaf_1003 | |||
| - lcms2=2.16=hb7c19ff_0 | |||
| - ld_impl_linux-64=2.43=h712a8e2_2 | |||
| - lerc=4.0.0=h27087fc_0 | |||
| - libasprintf=0.22.5=he8f35ee_3 | |||
| - libasprintf-devel=0.22.5=he8f35ee_3 | |||
| - libblas=3.9.0=16_linux64_mkl | |||
| - libcblas=3.9.0=16_linux64_mkl | |||
| - libcublas=12.4.5.8=0 | |||
| - libcufft=11.2.1.3=0 | |||
| - libcufile=1.11.1.6=0 | |||
| - libcurand=10.3.7.77=0 | |||
| - libcusolver=11.6.1.9=0 | |||
| - libcusparse=12.3.1.170=0 | |||
| - libdeflate=1.23=h4ddbbb0_0 | |||
| - libdrm=2.4.124=hb9d3cd8_0 | |||
| - libegl=1.7.0=ha4b6fd6_2 | |||
| - libexpat=2.6.4=h5888daf_0 | |||
| - libffi=3.4.2=h7f98852_5 | |||
| - libgcc=14.2.0=h77fa898_1 | |||
| - libgcc-ng=14.2.0=h69a702a_1 | |||
| - libgettextpo=0.22.5=he02047a_3 | |||
| - libgettextpo-devel=0.22.5=he02047a_3 | |||
| - libgl=1.7.0=ha4b6fd6_2 | |||
| - libglvnd=1.7.0=ha4b6fd6_2 | |||
| - libglx=1.7.0=ha4b6fd6_2 | |||
| - libgomp=14.2.0=h77fa898_1 | |||
| - libiconv=1.17=hd590300_2 | |||
| - libidn2=2.3.7=hd590300_0 | |||
| - libjpeg-turbo=3.0.0=hd590300_1 | |||
| - liblapack=3.9.0=16_linux64_mkl | |||
| - liblzma=5.6.3=hb9d3cd8_1 | |||
| - libnpp=12.2.5.30=0 | |||
| - libnsl=2.0.1=hd590300_0 | |||
| - libnvfatbin=12.6.77=0 | |||
| - libnvjitlink=12.4.127=0 | |||
| - libnvjpeg=12.3.1.117=0 | |||
| - libpciaccess=0.18=hd590300_0 | |||
| - libpng=1.6.45=h943b412_0 | |||
| - libsqlite=3.47.2=hee588c1_0 | |||
| - libstdcxx=14.2.0=hc0a3c3a_1 | |||
| - libstdcxx-ng=14.2.0=h4852527_1 | |||
| - libtasn1=4.19.0=h166bdaf_0 | |||
| - libtiff=4.7.0=hd9ff511_3 | |||
| - libunistring=0.9.10=h7f98852_0 | |||
| - libuuid=2.38.1=h0b41bf4_0 | |||
| - libva=2.22.0=h8a09558_1 | |||
| - libvpx=1.13.1=h59595ed_0 | |||
| - libwebp=1.5.0=hae8dbeb_0 | |||
| - libwebp-base=1.5.0=h851e524_0 | |||
| - libxcb=1.17.0=h8a09558_0 | |||
| - libxcrypt=4.4.36=hd590300_1 | |||
| - libxml2=2.13.5=h0d44e9d_1 | |||
| - libzlib=1.3.1=hb9d3cd8_2 | |||
| - llvm-openmp=15.0.7=h0cdce71_0 | |||
| - markupsafe=3.0.2=py311h2dc5d0c_1 | |||
| - mkl=2022.1.0=hc2b9512_224 | |||
| - mpc=1.3.1=h24ddda3_1 | |||
| - mpfr=4.2.1=h90cbb55_3 | |||
| - mpmath=1.3.0=pyhd8ed1ab_1 | |||
| - ncurses=6.5=h2d0b736_2 | |||
| - nettle=3.9.1=h7ab15ed_0 | |||
| - networkx=3.4.2=pyh267e887_2 | |||
| - numpy=2.2.1=py311hf916aec_0 | |||
| - openh264=2.3.1=hcb278e6_2 | |||
| - openjpeg=2.5.3=h5fbd93e_0 | |||
| - openssl=3.4.0=h7b32b05_1 | |||
| - p11-kit=0.24.1=hc5aa10d_0 | |||
| - pillow=11.1.0=py311h1322bbf_0 | |||
| - pip=24.3.1=pyh8b19718_2 | |||
| - pthread-stubs=0.4=hb9d3cd8_1002 | |||
| - pycparser=2.22=pyh29332c3_1 | |||
| - pysocks=1.7.1=pyha55dd90_7 | |||
| - python=3.11.11=h9e4cc4f_1_cpython | |||
| - python_abi=3.11=5_cp311 | |||
| - pytorch=2.5.1=py3.11_cuda12.4_cudnn9.1.0_0 | |||
| - pytorch-cuda=12.4=hc786d27_7 | |||
| - pytorch-mutex=1.0=cuda | |||
| - pyyaml=6.0.2=py311h9ecbd09_1 | |||
| - readline=8.2=h8228510_1 | |||
| - requests=2.32.3=pyhd8ed1ab_1 | |||
| - setuptools=75.8.0=pyhff2d567_0 | |||
| - svt-av1=1.4.1=hcb278e6_0 | |||
| - sympy=1.13.3=pyh2585a3b_105 | |||
| - tk=8.6.13=noxft_h4845f30_101 | |||
| - torchaudio=2.5.1=py311_cu124 | |||
| - torchtriton=3.1.0=py311 | |||
| - torchvision=0.20.1=py311_cu124 | |||
| - typing_extensions=4.12.2=pyha770c72_1 | |||
| - tzdata=2024b=hc8b5060_0 | |||
| - urllib3=2.3.0=pyhd8ed1ab_0 | |||
| - wayland=1.23.1=h3e06ad9_0 | |||
| - wayland-protocols=1.37=hd8ed1ab_0 | |||
| - wheel=0.45.1=pyhd8ed1ab_1 | |||
| - x264=1!164.3095=h166bdaf_2 | |||
| - x265=3.5=h924138e_3 | |||
| - xorg-libx11=1.8.10=h4f16b4b_1 | |||
| - xorg-libxau=1.0.12=hb9d3cd8_0 | |||
| - xorg-libxdmcp=1.1.5=hb9d3cd8_0 | |||
| - xorg-libxext=1.3.6=hb9d3cd8_0 | |||
| - xorg-libxfixes=6.0.1=hb9d3cd8_0 | |||
| - yaml=0.2.5=h7f98852_2 | |||
| - zstandard=0.23.0=py311hbc35293_1 | |||
| - zstd=1.5.6=ha6fb4c9_0 | |||
| - pip: | |||
| - torchmetrics>=1.4 # nicer MRR/Hits@K utilities | |||
| - lovely-tensors>=0.1.0 # tensor utilities | |||
| - lovely-numpy>=0.1.0 # numpy utilities | |||
| - einops==0.8.1 | |||
| - pykeen==1.11.1 | |||
| - hydra-core==1.3.2 | |||
| - tensorboard==2.19.0 | |||
| @@ -0,0 +1,6 @@ | |||
| from .pretty_logger import get_pretty_logger | |||
| from .sampling import DeterministicRandomSampler | |||
| from .utils import set_seed | |||
| from .tb_handler import TensorBoardHandler | |||
| from .params import CommonParams, TrainingParams | |||
| from .checkpoint_manager import CheckpointManager | |||
| @@ -0,0 +1,221 @@ | |||
| from __future__ import annotations | |||
import ast
import re
| from pathlib import Path | |||
| class CheckpointManager: | |||
| """ | |||
| A checkpoint manager that handles checkpoint directory creation and path resolution. | |||
| Features: | |||
| - Creates sequentially numbered checkpoint directories | |||
| - Supports two loading modes: by components or by full path | |||
| """ | |||
| def __init__( | |||
| self, | |||
| root_directory: str | Path, | |||
| run_name: str, | |||
| load_only: bool = False, | |||
| ) -> None: | |||
| """ | |||
| Initialize the checkpoint manager. | |||
| Args: | |||
| root_directory: Root directory for checkpoints | |||
run_name: Name of the run (used as suffix for checkpoint directories)
load_only: If True, only resolve existing paths; no new checkpoint directory is created
"""
self.root_directory = Path(root_directory)
self.run_name = run_name
self.checkpoint_directory: Path | None = None
if not load_only:
self.root_directory.mkdir(parents=True, exist_ok=True)
self.checkpoint_directory = self._create_checkpoint_directory()
| def _find_existing_directories(self) -> list[int]: | |||
| """ | |||
| Find all existing directories with pattern xxx_<run_name> where xxx is a 3-digit number. | |||
| Returns: | |||
| List of existing sequence numbers | |||
| """ | |||
| pattern = re.compile(rf"^(\d{{3}})_{re.escape(self.run_name)}$") | |||
| existing_numbers = [] | |||
| if self.root_directory.exists(): | |||
| for item in self.root_directory.iterdir(): | |||
| if item.is_dir(): | |||
| match = pattern.match(item.name) | |||
| if match: | |||
| existing_numbers.append(int(match.group(1))) | |||
| return sorted(existing_numbers) | |||
| def _create_checkpoint_directory(self) -> Path: | |||
| """ | |||
| Create a new checkpoint directory with the next sequential number. | |||
| Returns: | |||
| Path to the created checkpoint directory | |||
| """ | |||
| existing_numbers = self._find_existing_directories() | |||
| # Determine the next number | |||
| next_number = max(existing_numbers) + 1 if existing_numbers else 1 | |||
| # Create directory name with 3-digit zero-padded number | |||
| dir_name = f"{next_number:03d}_{self.run_name}" | |||
| checkpoint_dir = self.root_directory / dir_name | |||
| self.run_name = dir_name | |||
| # Create the directory | |||
| checkpoint_dir.mkdir(parents=True, exist_ok=True) | |||
| return checkpoint_dir | |||
| def get_checkpoint_directory(self) -> Path: | |||
| """ | |||
| Get the current checkpoint directory. | |||
| Returns: | |||
| Path to the checkpoint directory | |||
| """ | |||
| return self.checkpoint_directory | |||
| def get_model_fpath(self, model_path: str) -> Path: | |||
| """ | |||
| Get the full path to a model checkpoint file. | |||
| Args: | |||
| model_path: Either a tuple (model_id, iteration) or a string path | |||
| Returns: | |||
| Full path to the model checkpoint file | |||
| """ | |||
| import ast | |||
| try: | |||
| # Parse a literal tuple such as "(3, 1000)" safely instead of calling eval() | |||
| parsed = ast.literal_eval(model_path) | |||
| if isinstance(parsed, tuple): | |||
| model_id, iteration = parsed | |||
| checkpoint_directory = f"{self.root_directory}/{model_id:03d}_{self.run_name}" | |||
| filename = f"model-{iteration}.pt" | |||
| return Path(f"{checkpoint_directory}/{filename}") | |||
| except (ValueError, SyntaxError): | |||
| pass | |||
| if isinstance(model_path, str): | |||
| return Path(model_path) | |||
| msg = "model_path must be a tuple (model_id, iteration) or a string path" | |||
| raise ValueError(msg) | |||
| def get_model_path_by_args( | |||
| self, | |||
| model_id: str | None = None, | |||
| iteration: int | str | None = None, | |||
| full_path: str | Path | None = None, | |||
| ) -> Path: | |||
| """ | |||
| Get the path for loading a model checkpoint. | |||
| Two modes of operation: | |||
| 1. Component mode: Provide model_id and iteration to construct path | |||
| 2. Full path mode: Provide full_path directly | |||
| Args: | |||
| model_id: Model identifier (used in component mode) | |||
| iteration: Training iteration (used in component mode) | |||
| full_path: Full path to checkpoint (used in full path mode) | |||
| Returns: | |||
| Path to the checkpoint file | |||
| Raises: | |||
| ValueError: If neither component parameters nor full_path are provided, | |||
| or if both modes are attempted simultaneously | |||
| """ | |||
| # Check which mode we're operating in | |||
| component_mode = model_id is not None or iteration is not None | |||
| full_path_mode = full_path is not None | |||
| if component_mode and full_path_mode: | |||
| msg = ( | |||
| "Cannot use both component mode (model_id/iteration) " | |||
| "and full path mode simultaneously" | |||
| ) | |||
| raise ValueError(msg) | |||
| if not component_mode and not full_path_mode: | |||
| msg = "Must provide either (model_id and iteration) or full_path" | |||
| raise ValueError(msg) | |||
| if full_path_mode: | |||
| # Full path mode: return the path as-is | |||
| return Path(full_path) | |||
| # Component mode: construct path from checkpoint directory, model_id, and iteration | |||
| if model_id is None or iteration is None: | |||
| msg = "Both model_id and iteration must be provided in component mode" | |||
| raise ValueError(msg) | |||
| filename = f"{model_id}_iter_{iteration}.pt" | |||
| return self.checkpoint_directory / filename | |||
| def save_checkpoint_info(self, info: dict) -> None: | |||
| """ | |||
| Save checkpoint information to a JSON file in the checkpoint directory. | |||
| Args: | |||
| info: Dictionary containing checkpoint metadata | |||
| """ | |||
| import json | |||
| info_file = self.checkpoint_directory / "checkpoint_info.json" | |||
| with open(info_file, "w") as f: | |||
| json.dump(info, f, indent=2) | |||
| def load_checkpoint_info(self) -> dict: | |||
| """ | |||
| Load checkpoint information from the checkpoint directory. | |||
| Returns: | |||
| Dictionary containing checkpoint metadata | |||
| """ | |||
| import json | |||
| info_file = self.checkpoint_directory / "checkpoint_info.json" | |||
| if info_file.exists(): | |||
| with open(info_file) as f: | |||
| return json.load(f) | |||
| return {} | |||
| def __str__(self) -> str: | |||
| return ( | |||
| f"CheckpointManager(root='{self.root_directory}', " | |||
| f"run='{self.run_name}', ckpt_dir='{self.checkpoint_directory}')" | |||
| ) | |||
| def __repr__(self) -> str: | |||
| return self.__str__() | |||
| # Example usage: | |||
| if __name__ == "__main__": | |||
| import tempfile | |||
| # Create checkpoint manager with proper temp directory | |||
| with tempfile.TemporaryDirectory() as temp_dir: | |||
| ckpt_manager = CheckpointManager(temp_dir, "my_experiment") | |||
| print(f"Checkpoint directory: {ckpt_manager.get_checkpoint_directory()}") | |||
| # Component mode - construct path from components | |||
| model_path = ckpt_manager.get_model_path_by_args(model_id="transe", iteration=1000) | |||
| print(f"Model path (component mode): {model_path}") | |||
| # Full path mode - use existing full path | |||
| full_path = "/some/other/path/model.pt" | |||
| model_path = ckpt_manager.get_model_path_by_args(full_path=full_path) | |||
| print(f"Model path (full path mode): {model_path}") | |||
| @@ -0,0 +1,45 @@ | |||
| from __future__ import annotations | |||
| from dataclasses import dataclass | |||
| @dataclass | |||
| class CommonParams: | |||
| """ | |||
| Common parameters for the application. | |||
| """ | |||
| model_name: str | |||
| project_name: str | |||
| run_name: str | |||
| save_dpath: str | |||
| save_every: int = 1000 | |||
| load_path: str | None = "" | |||
| log_dir: str = "./runs" | |||
| log_every: int = 10 | |||
| log_console_every: int = 100 | |||
| evaluate_only: bool = False | |||
| @dataclass | |||
| class TrainingParams: | |||
| """ | |||
| Parameters for training. | |||
| """ | |||
| lr: float = 0.001 | |||
| min_lr: float = 0.0001 | |||
| weight_decay: float = 0.0 | |||
| t0: int = 50 | |||
| lr_step: int = 10 | |||
| gamma: float = 1.0 | |||
| batch_size: int = 128 | |||
| num_workers: int = 0 | |||
| seed: int = 42 | |||
| num_train_steps: int = 100 | |||
| num_epochs: int = 100 | |||
| eval_every: int = 5 | |||
| validation_sample_size: int = 1000 | |||
| validation_batch_size: int = 128 | |||
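| # Example usage (illustrative sketch; the values below are placeholders, not project defaults): | |||
| if __name__ == "__main__": | |||
|     common = CommonParams( | |||
|         model_name="transe", | |||
|         project_name="kg_eval", | |||
|         run_name="demo_run", | |||
|         save_dpath="./checkpoints", | |||
|     ) | |||
|     training = TrainingParams(lr=0.01, batch_size=256, num_train_steps=10) | |||
|     print(common) | |||
|     print(training) | |||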
| @@ -0,0 +1,417 @@ | |||
| # pretty_logger.py | |||
| import json | |||
| import logging | |||
| import logging.handlers | |||
| import sys | |||
| import threading | |||
| import time | |||
| from contextlib import ContextDecorator | |||
| from functools import wraps | |||
| # ANSI escape codes for colors | |||
| RESET = "\x1b[0m" | |||
| BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = ( | |||
| "\x1b[30m", | |||
| "\x1b[31m", | |||
| "\x1b[32m", | |||
| "\x1b[33m", | |||
| "\x1b[34m", | |||
| "\x1b[35m", | |||
| "\x1b[36m", | |||
| "\x1b[37m", | |||
| ) | |||
| _LEVEL_COLORS = { | |||
| logging.DEBUG: CYAN, | |||
| logging.INFO: GREEN, | |||
| logging.WARNING: YELLOW, | |||
| logging.ERROR: RED, | |||
| logging.CRITICAL: MAGENTA, | |||
| } | |||
| def _supports_color() -> bool: | |||
| """ | |||
| Detect whether the running terminal supports ANSI colors. | |||
| """ | |||
| if hasattr(sys.stdout, "isatty") and sys.stdout.isatty(): | |||
| platform = sys.platform | |||
| if platform == "win32": | |||
| # On Windows, modern terminals may support ANSI; assume yes for Windows 10+ | |||
| return True | |||
| return True | |||
| return False | |||
| class PrettyFormatter(logging.Formatter): | |||
| """ | |||
| A logging.Formatter that outputs colorized logs to the console | |||
| (if enabled) and supports JSON formatting (if requested). | |||
| """ | |||
| def __init__( | |||
| self, | |||
| fmt: str | None = None, | |||
| datefmt: str = "%Y-%m-%d %H:%M:%S", | |||
| use_colors: bool = True, | |||
| json_output: bool = False, | |||
| ): | |||
| super().__init__(fmt=fmt, datefmt=datefmt) | |||
| self.use_colors = use_colors and _supports_color() | |||
| self.json_output = json_output | |||
| def format(self, record: logging.LogRecord) -> str: | |||
| # If JSON output is requested, dump a dict of relevant fields | |||
| if self.json_output: | |||
| log_record = { | |||
| "timestamp": self.formatTime(record, self.datefmt), | |||
| "name": record.name, | |||
| "level": record.levelname, | |||
| "message": record.getMessage(), | |||
| } | |||
| # Include extra/contextual fields | |||
| for k, v in record.__dict__.get("extra_fields", {}).items(): | |||
| log_record[k] = v | |||
| if record.exc_info: | |||
| log_record["exception"] = self.formatException(record.exc_info) | |||
| return json.dumps(log_record, ensure_ascii=False) | |||
| # Otherwise, build a colorized/“pretty” line | |||
| level_color = _LEVEL_COLORS.get(record.levelno, WHITE) if self.use_colors else "" | |||
| reset = RESET if self.use_colors else "" | |||
| timestamp = self.formatTime(record, self.datefmt) | |||
| name = record.name | |||
| level = record.levelname | |||
| # Basic format: [timestamp] [LEVEL] [name]: message | |||
| base = f"[{timestamp}] [{level}] [{name}]: {record.getMessage()}" | |||
| # If there are any extra fields attached via LoggerAdapter, append them | |||
| extra_fields = record.__dict__.get("extra_fields", {}) | |||
| if extra_fields: | |||
| # key1=val1 key2=val2 … | |||
| extras_str = " ".join(f"{k}={v}" for k, v in extra_fields.items()) | |||
| base = f"{base} :: {extras_str}" | |||
| # If exception info is present, include traceback | |||
| if record.exc_info: | |||
| exc_text = self.formatException(record.exc_info) | |||
| base = f"{base}\n{exc_text}" | |||
| if self.use_colors: | |||
| return f"{level_color}{base}{reset}" | |||
| else: | |||
| return base | |||
| class JsonFormatter(PrettyFormatter): | |||
| """ | |||
| Shortcut to force JSON formatting (ignores color flags). | |||
| """ | |||
| def __init__(self, datefmt: str = "%Y-%m-%d %H:%M:%S"): | |||
| super().__init__(fmt=None, datefmt=datefmt, use_colors=False, json_output=True) | |||
| class ContextFilter(logging.Filter): | |||
| """ | |||
| Inject extra_fields from a LoggerAdapter into the LogRecord, so the Formatter can find them. | |||
| """ | |||
| def filter(self, record: logging.LogRecord) -> bool: | |||
| extra = getattr(record, "extra", None) | |||
| if isinstance(extra, dict): | |||
| record.extra_fields = extra | |||
| else: | |||
| record.extra_fields = {} | |||
| return True | |||
| class Timer(ContextDecorator): | |||
| """ | |||
| Context manager & decorator to measure execution time of a code block or function. | |||
| Usage as context manager: | |||
| with Timer(logger, "block name"): | |||
| # code … | |||
| Usage as decorator: | |||
| @Timer(logger, "function foo") | |||
| def foo(...): | |||
| ... | |||
| """ | |||
| def __init__(self, logger: logging.Logger, name: str | None = None, level: int = logging.INFO): | |||
| self.logger = logger | |||
| self.name = name or "" | |||
| self.level = level | |||
| def __enter__(self): | |||
| self.start_time = time.time() | |||
| return self | |||
| def __exit__(self, exc_type, exc, exc_tb): | |||
| elapsed = time.time() - self.start_time | |||
| label = f"[Timer:{self.name}]" if self.name else "[Timer]" | |||
| self.logger.log(self.level, f"{label} elapsed {elapsed:.4f} sec") | |||
| return False # do not suppress exceptions | |||
| def log_function(logger: logging.Logger, level: int = logging.DEBUG): | |||
| """ | |||
| Decorator factory to log function entry/exit and execution time. | |||
| Usage: | |||
| @log_function(logger, level=logging.INFO) | |||
| def my_func(...): | |||
| ... | |||
| """ | |||
| def decorator(func): | |||
| @wraps(func) | |||
| def wrapper(*args, **kwargs): | |||
| logger.log(level, f"➤ Entering {func.__name__}()") | |||
| start = time.time() | |||
| try: | |||
| result = func(*args, **kwargs) | |||
| except Exception: | |||
| logger.exception(f"✖ Exception in {func.__name__}()") | |||
| raise | |||
| elapsed = time.time() - start | |||
| logger.log(level, f"✔ Exiting {func.__name__}(), elapsed {elapsed:.4f} sec") | |||
| return result | |||
| return wrapper | |||
| return decorator | |||
| class PrettyLogger: | |||
| """ | |||
| Encapsulates a logging.Logger with “pretty” formatting, colorized console output, | |||
| rotating file handler, JSON option, contextual fields, etc. | |||
| """ | |||
| _instances = {} | |||
| _lock = threading.Lock() | |||
| def __new__(cls, name: str = "root", **kwargs): | |||
| """ | |||
| Implements a simple singleton: multiple calls with the same name return the same instance. | |||
| """ | |||
| with cls._lock: | |||
| if name not in cls._instances: | |||
| instance = super().__new__(cls) | |||
| cls._instances[name] = instance | |||
| return cls._instances[name] | |||
| def __init__( | |||
| self, | |||
| name: str = "root", | |||
| level: int = logging.DEBUG, | |||
| log_to_console: bool = True, | |||
| console_level: int = logging.DEBUG, | |||
| log_to_file: bool = False, | |||
| log_file: str = "app.log", | |||
| file_level: int = logging.INFO, | |||
| max_bytes: int = 10 * 1024 * 1024, # 10 MB | |||
| backup_count: int = 5, | |||
| formatter: PrettyFormatter = None, | |||
| json_output: bool = False, | |||
| ): | |||
| """ | |||
| Initialize (or re‐configure) a PrettyLogger. | |||
| Args: | |||
| name: logger name. | |||
| level: root level for the logger (lowest level that will be processed). | |||
| log_to_console: whether to attach a console (StreamHandler). | |||
| console_level: level for console output. | |||
| log_to_file: whether to attach a rotating file handler. | |||
| log_file: path to the log file. | |||
| file_level: level for file output. | |||
| max_bytes: rotation threshold in bytes. | |||
| backup_count: number of backup files to keep. | |||
| formatter: an instance of PrettyFormatter. If None, a default is created. | |||
| json_output: if True, forces JSON formatting on all handlers. | |||
| """ | |||
| # Avoid re‐initializing if already done | |||
| logger = logging.getLogger(name) | |||
| if getattr(logger, "_pretty_logger_inited", False): | |||
| return | |||
| self.logger = logger | |||
| self.logger.setLevel(level) | |||
| self.logger.propagate = False # avoid double‐logging if root logger is configured elsewhere | |||
| # Default line format; also reused by the rotating file handler below | |||
| fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s" | |||
| # Create a default formatter if not supplied | |||
| if formatter is None: | |||
| formatter = PrettyFormatter( | |||
| fmt=fmt, use_colors=not json_output, json_output=json_output | |||
| ) | |||
| # Add ContextFilter so extra_fields is always present | |||
| context_filter = ContextFilter() | |||
| self.logger.addFilter(context_filter) | |||
| # Console handler | |||
| if log_to_console: | |||
| ch = logging.StreamHandler(sys.stdout) | |||
| ch.setLevel(console_level) | |||
| ch.setFormatter(formatter) | |||
| self.logger.addHandler(ch) | |||
| # Rotating file handler | |||
| if log_to_file: | |||
| fh = logging.handlers.RotatingFileHandler( | |||
| filename=log_file, | |||
| maxBytes=max_bytes, | |||
| backupCount=backup_count, | |||
| encoding="utf-8", | |||
| ) | |||
| fh.setLevel(file_level) | |||
| # For file, typically we don’t colorize | |||
| file_formatter = ( | |||
| JsonFormatter(datefmt="%Y-%m-%d %H:%M:%S") | |||
| if json_output | |||
| else PrettyFormatter(fmt=fmt, use_colors=False, json_output=False) | |||
| ) | |||
| fh.setFormatter(file_formatter) | |||
| self.logger.addHandler(fh) | |||
| # Mark as initialized | |||
| setattr(self.logger, "_pretty_logger_inited", True) | |||
| def addHandler(self, handler: logging.Handler): | |||
| """ | |||
| Add a custom handler to the logger. | |||
| This can be used to add any logging.Handler instance. | |||
| """ | |||
| self.logger.addHandler(handler) | |||
| def bind(self, **kwargs): | |||
| """ | |||
| Return a LoggerAdapter that automatically injects extra fields into all log records. | |||
| Example: | |||
| ctx_logger = pretty_logger.bind(request_id=123, user="alice") | |||
| ctx_logger.info("Something happened") # will include request_id and user in the output | |||
| """ | |||
| return logging.LoggerAdapter(self.logger, {"extra": kwargs}) | |||
| def add_stream_handler( | |||
| self, stream=None, level: int = logging.DEBUG, formatter: PrettyFormatter = None | |||
| ): | |||
| """ | |||
| Add an additional stream handler (e.g. to stderr). | |||
| """ | |||
| handler = logging.StreamHandler(stream or sys.stdout) | |||
| handler.setLevel(level) | |||
| if formatter is None: | |||
| fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s" | |||
| formatter = PrettyFormatter(fmt=fmt) | |||
| handler.setFormatter(formatter) | |||
| self.logger.addHandler(handler) | |||
| return handler | |||
| def add_rotating_file_handler( | |||
| self, | |||
| log_file: str, | |||
| level: int = logging.INFO, | |||
| max_bytes: int = 10 * 1024 * 1024, | |||
| backup_count: int = 5, | |||
| json_output: bool = False, | |||
| ): | |||
| """ | |||
| Add another rotating file handler at runtime. | |||
| """ | |||
| fh = logging.handlers.RotatingFileHandler( | |||
| filename=log_file, | |||
| maxBytes=max_bytes, | |||
| backupCount=backup_count, | |||
| encoding="utf-8", | |||
| ) | |||
| fh.setLevel(level) | |||
| if json_output: | |||
| formatter = JsonFormatter() | |||
| else: | |||
| fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s" | |||
| formatter = PrettyFormatter(fmt=fmt, use_colors=False, json_output=False) | |||
| fh.setFormatter(formatter) | |||
| self.logger.addHandler(fh) | |||
| return fh | |||
| def remove_handler(self, handler: logging.Handler): | |||
| """ | |||
| Remove a handler from the logger. | |||
| """ | |||
| self.logger.removeHandler(handler) | |||
| def set_level(self, level: int): | |||
| """ | |||
| Change the logger’s root level at runtime. | |||
| """ | |||
| self.logger.setLevel(level) | |||
| def get_logger(self) -> logging.Logger: | |||
| return self.logger | |||
| # Convenience methods to expose common logging calls directly from this object: | |||
| def debug(self, msg, *args, **kwargs): | |||
| self.logger.debug(msg, *args, **kwargs) | |||
| def info(self, msg, *args, **kwargs): | |||
| self.logger.info(msg, *args, **kwargs) | |||
| def warning(self, msg, *args, **kwargs): | |||
| self.logger.warning(msg, *args, **kwargs) | |||
| def error(self, msg, *args, **kwargs): | |||
| self.logger.error(msg, *args, **kwargs) | |||
| def critical(self, msg, *args, **kwargs): | |||
| self.logger.critical(msg, *args, **kwargs) | |||
| def exception(self, msg, *args, exc_info=True, **kwargs): | |||
| """ | |||
| Log an error message with exception traceback. By default exc_info=True. | |||
| """ | |||
| self.logger.error(msg, *args, exc_info=exc_info, **kwargs) | |||
| # Convenience function for a module‐level “global” pretty logger: | |||
| _global_loggers = {} | |||
| _global_lock = threading.Lock() | |||
| def get_pretty_logger( | |||
| name: str = "root", | |||
| level: int = logging.DEBUG, | |||
| log_to_console: bool = True, | |||
| console_level: int = logging.DEBUG, | |||
| log_to_file: bool = False, | |||
| log_file: str = "app.log", | |||
| file_level: int = logging.INFO, | |||
| max_bytes: int = 10 * 1024 * 1024, | |||
| backup_count: int = 5, | |||
| json_output: bool = False, | |||
| ) -> PrettyLogger: | |||
| """ | |||
| Return a PrettyLogger instance for `name`, creating/configuring it on first use. | |||
| Subsequent calls with the same `name` will return the same logger (singleton‐style). | |||
| """ | |||
| with _global_lock: | |||
| if name not in _global_loggers: | |||
| pl = PrettyLogger( | |||
| name=name, | |||
| level=level, | |||
| log_to_console=log_to_console, | |||
| console_level=console_level, | |||
| log_to_file=log_to_file, | |||
| log_file=log_file, | |||
| file_level=file_level, | |||
| max_bytes=max_bytes, | |||
| backup_count=backup_count, | |||
| json_output=json_output, | |||
| ) | |||
| _global_loggers[name] = pl | |||
| return _global_loggers[name] | |||
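| # Example usage (illustrative sketch exercising the helpers defined above): | |||
| if __name__ == "__main__": | |||
|     log = get_pretty_logger("demo") | |||
|     log.info("plain message") | |||
|     ctx = log.bind(run_id=7, stage="train") | |||
|     ctx.info("message with contextual fields appended as key=value pairs") | |||
|     with Timer(log.get_logger(), "short sleep"): | |||
|         time.sleep(0.05) | |||
|     @log_function(log.get_logger(), level=logging.INFO) | |||
|     def add(a, b): | |||
|         return a + b | |||
|     add(1, 2) | |||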
| @@ -0,0 +1,66 @@ | |||
| import torch | |||
| import numpy as np | |||
| from collections.abc import Iterator | |||
| from torch.utils.data.sampler import Sampler | |||
| def negative_sampling( | |||
| pos: torch.Tensor, | |||
| num_entities: int, | |||
| num_negatives: int = 1, | |||
| mode: list = ["head", "tail"], | |||
| ) -> torch.Tensor: | |||
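| # Repeat each positive triple num_negatives times, then overwrite the head and/or | |||
| # tail column with uniformly sampled entity ids; if both modes are requested, every | |||
| # negative has both ends replaced. repeat_interleave returns a copy, so the caller's | |||
| # tensor is not modified in place. | |||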
| pos = pos.repeat_interleave(num_negatives, dim=0) | |||
| if "head" in mode: | |||
| neg = torch.randint(0, num_entities, (pos.shape[0],), device=pos.device) | |||
| pos[:, 0] = neg | |||
| if "tail" in mode: | |||
| neg = torch.randint(0, num_entities, (pos.shape[0],), device=pos.device) | |||
| pos[:, 2] = neg | |||
| return pos | |||
| class DeterministicRandomSampler(Sampler): | |||
| """ | |||
| A deterministic random sampler that selects a fixed number of samples | |||
| in a reproducible manner using a given seed. | |||
| """ | |||
| def __init__(self, dataset_size: int, sample_size: int = 0, seed: int = 42) -> None: | |||
| """ | |||
| Args: | |||
| dataset_size (int): Total number of samples in the dataset. | |||
| sample_size (int, optional): Number of samples to draw. If 0, use the full dataset. | |||
| seed (int): Seed for reproducibility. | |||
| """ | |||
| self.dataset_size = dataset_size | |||
| self.seed = seed | |||
| self.sample_size = sample_size if sample_size != 0 else self.dataset_size | |||
| # Ensure sample size is within dataset size | |||
| if self.sample_size > self.dataset_size: | |||
| msg = f"Sample size {self.sample_size} exceeds dataset size {self.dataset_size}." | |||
| raise ValueError( | |||
| msg, | |||
| ) | |||
| self.indices = self._generate_deterministic_indices() | |||
| def _generate_deterministic_indices(self) -> list[int]: | |||
| """ | |||
| Generates a fixed random subset of indices using the given seed. | |||
| """ | |||
| rng = np.random.default_rng(self.seed) # NumPy's Generator for better reproducibility | |||
| all_indices = rng.permutation(self.dataset_size) # Shuffle full dataset indices | |||
| return all_indices[: self.sample_size].tolist() # Select only the desired number of samples | |||
| def __iter__(self) -> Iterator[int]: | |||
| """ | |||
| Yields the shuffled dataset indices. | |||
| """ | |||
| return iter(self.indices) | |||
| def __len__(self) -> int: | |||
| """ | |||
| Returns the total number of samples drawn. | |||
| """ | |||
| return self.sample_size | |||
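| # Example usage (illustrative sketch with toy sizes): | |||
| if __name__ == "__main__": | |||
|     from torch.utils.data import DataLoader, TensorDataset | |||
|     pos = torch.tensor([[0, 0, 1], [2, 1, 3]])  # (head, relation, tail) triples | |||
|     neg = negative_sampling(pos, num_entities=5, num_negatives=2, mode=("tail",)) | |||
|     print(neg.shape)  # torch.Size([4, 3]) | |||
|     dataset = TensorDataset(torch.arange(100)) | |||
|     sampler = DeterministicRandomSampler(len(dataset), sample_size=10, seed=0) | |||
|     loader = DataLoader(dataset, batch_size=5, sampler=sampler) | |||
|     print([batch[0].tolist() for batch in loader])  # identical indices on every run | |||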
| @@ -0,0 +1,21 @@ | |||
| from __future__ import annotations | |||
| import logging | |||
| from typing import TYPE_CHECKING | |||
| if TYPE_CHECKING: | |||
| from torch.utils.tensorboard import SummaryWriter | |||
| class TensorBoardHandler(logging.Handler): | |||
| """Convert each log record into a TensorBoard text summary.""" | |||
| def __init__(self, writer: SummaryWriter, tag: str = "logs"): | |||
| super().__init__(level=logging.INFO) | |||
| self.writer = writer | |||
| self.tag = tag | |||
| self._step = 0 # will be incremented every emit | |||
| def emit(self, record: logging.LogRecord) -> None: | |||
| self.writer.add_text(self.tag, self.format(record), global_step=self._step) | |||
| self._step += 1 | |||
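| # Example usage (illustrative sketch; writes text summaries under ./runs/demo): | |||
| if __name__ == "__main__": | |||
|     from torch.utils.tensorboard import SummaryWriter | |||
|     writer = SummaryWriter(log_dir="./runs/demo") | |||
|     handler = TensorBoardHandler(writer, tag="demo_logs") | |||
|     handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) | |||
|     log = logging.getLogger("tb_demo") | |||
|     log.setLevel(logging.INFO) | |||
|     log.addHandler(handler) | |||
|     log.info("this record is stored in TensorBoard as a text summary") | |||
|     writer.close() | |||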
| @@ -0,0 +1,12 @@ | |||
| from torch.optim.lr_scheduler import StepLR | |||
| class StepMinLR(StepLR): | |||
| def __init__(self, optimizer, step_size, gamma=0.1, min_lr=0.0, last_epoch=-1): | |||
| self.min_lr = min_lr | |||
| super().__init__(optimizer, step_size, gamma, last_epoch) | |||
| def get_lr(self): | |||
| # use StepLR's formula, then clamp | |||
| raw_lrs = super().get_lr() | |||
| return [max(lr, self.min_lr) for lr in raw_lrs] | |||
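| # Example usage (illustrative sketch with a dummy parameter): decay by gamma every | |||
| # step_size epochs, but never below min_lr. | |||
| if __name__ == "__main__": | |||
|     import torch | |||
|     param = torch.nn.Parameter(torch.zeros(1)) | |||
|     opt = torch.optim.SGD([param], lr=0.1) | |||
|     sched = StepMinLR(opt, step_size=1, gamma=0.1, min_lr=0.005) | |||
|     for epoch in range(4): | |||
|         opt.step() | |||
|         sched.step() | |||
|         print(epoch, opt.param_groups[0]["lr"])  # 0.01, then clamped at 0.005 | |||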
| @@ -0,0 +1,25 @@ | |||
| import random | |||
| from enum import Enum | |||
| import numpy as np | |||
| import torch | |||
| def set_seed(seed: int = 42) -> None: | |||
| """ | |||
| Set the random seed for reproducibility. | |||
| Args: | |||
| seed (int): The seed value to set for random number generation. | |||
| """ | |||
| torch.manual_seed(seed) | |||
| random.seed(seed) | |||
| np.random.seed(seed)  # seed NumPy's global RNG; default_rng(seed) alone has no global effect | |||
| if torch.cuda.is_available(): | |||
| torch.cuda.manual_seed_all(seed) | |||
| class Target(str, Enum): # ↔ PyKEEN's LABEL_* constants | |||
| HEAD = "head" | |||
| TAIL = "tail" | |||
| RELATION = "relation" | |||
| @@ -0,0 +1,85 @@ | |||
| from __future__ import annotations | |||
| from abc import ABC, abstractmethod | |||
| from typing import TYPE_CHECKING | |||
| import torch | |||
| from torch import nn | |||
| if TYPE_CHECKING: | |||
| from torch.optim.lr_scheduler import LRScheduler | |||
| class ModelTrainerBase(ABC): | |||
| def __init__( | |||
| self, | |||
| model: nn.Module, | |||
| optimizer: torch.optim.Optimizer, | |||
| scheduler: LRScheduler | None = None, | |||
| device: torch.device = "cpu", | |||
| ) -> None: | |||
| self.model = model.to(device) | |||
| self.optimizer = optimizer | |||
| self.scheduler = scheduler | |||
| self.device = device | |||
| def save(self, save_dpath: str, step: int) -> None: | |||
| """ | |||
| Save the model and optimizer state to the specified directory. | |||
| Args: | |||
| save_dpath (str): The directory path where the model and optimizer state will be saved. | |||
| step (int): Current training step, used in the checkpoint filename. | |||
| """ | |||
| data = { | |||
| "model": self.model.state_dict(), | |||
| "optimizer": self.optimizer.state_dict(), | |||
| "step": step, | |||
| } | |||
| if self.scheduler is not None: | |||
| data["scheduler"] = self.scheduler.state_dict() | |||
| torch.save( | |||
| data, | |||
| f"{save_dpath}/model-{step}.pt", | |||
| ) | |||
| def load(self, load_fpath: str) -> int: | |||
| """ | |||
| Load the model and optimizer state from the specified directory. | |||
| Args: | |||
| load_fpath (str): Path to the checkpoint file from which the model and | |||
| optimizer state will be loaded. | |||
| Returns: | |||
| int: The training step stored in the checkpoint. | |||
| """ | |||
| data = torch.load(load_fpath, map_location="cpu") | |||
| model_state_dict = data["model"] | |||
| self.model.load_state_dict(model_state_dict) | |||
| self.model.to(self.device)  # move everything to the correct device | |||
| optimizer_state_dict = data["optimizer"] | |||
| self.optimizer.load_state_dict(optimizer_state_dict) | |||
| if self.scheduler is not None: | |||
| scheduler_data = data.get("scheduler", {}) | |||
| self.scheduler.load_state_dict(scheduler_data) | |||
| return data["step"] | |||
| @abstractmethod | |||
| def _loss(self, batch: torch.Tensor) -> torch.Tensor: | |||
| """ | |||
| Compute the loss for a batch of data. | |||
| Args: | |||
| batch (torch.Tensor): A tensor containing a batch of data. | |||
| Returns: | |||
| torch.Tensor: The computed loss for the batch. | |||
| """ | |||
| ... | |||
| @abstractmethod | |||
| def train_step(self, batch: torch.Tensor) -> float: ... | |||
| @abstractmethod | |||
| def eval_step(self, batch: torch.Tensor) -> float: ... | |||
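| # Illustrative sketch (hypothetical, not part of the training code): the smallest | |||
| # ModelTrainerBase subclass, showing what the abstract hooks are expected to do. | |||
| class _ToyTrainer(ModelTrainerBase): | |||
|     def _loss(self, batch: torch.Tensor) -> torch.Tensor: | |||
|         # mean squared model output as a stand-in objective | |||
|         return self.model(batch.float()).pow(2).mean() | |||
|     def train_step(self, batch: torch.Tensor) -> float: | |||
|         self.model.train() | |||
|         self.optimizer.zero_grad() | |||
|         loss = self._loss(batch) | |||
|         loss.backward() | |||
|         self.optimizer.step() | |||
|         return loss.item() | |||
|     def eval_step(self, batch: torch.Tensor) -> float: | |||
|         self.model.eval() | |||
|         with torch.no_grad(): | |||
|             return self._loss(batch).item() | |||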
| @@ -0,0 +1,148 @@ | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from einops import repeat | |||
| from pykeen.sampling import BernoulliNegativeSampler | |||
| from pykeen.training import SLCWATrainingLoop | |||
| from pykeen.triples.instances import SLCWABatch | |||
| from torch.optim import Adam | |||
| from torch.optim.lr_scheduler import StepLR | |||
| from torch.utils.data import DataLoader | |||
| from metrics.ranking import simple_ranking | |||
| from models.translation.trans_e import TransE | |||
| from tools.params import TrainingParams | |||
| from training.model_trainers.model_trainer_base import ModelTrainerBase | |||
| class TransETrainer(ModelTrainerBase): | |||
| def __init__( | |||
| self, | |||
| model: TransE, | |||
| training_params: TrainingParams, | |||
| triples_factory: torch.Tensor, | |||
| mapped_triples: torch.Tensor, | |||
| device: torch.device, | |||
| margin: float = 1.0, | |||
| n_negative: int = 1, | |||
| loss_fn: str = "margin", | |||
| **kwargs, | |||
| ) -> None: | |||
| optimizer = Adam( | |||
| model.inner_model.parameters(), | |||
| lr=training_params.lr, | |||
| weight_decay=training_params.weight_decay, | |||
| ) | |||
| # scheduler = None # Placeholder for scheduler, can be implemented later | |||
| scheduler = StepLR( | |||
| optimizer, | |||
| step_size=training_params.lr_step, | |||
| gamma=training_params.gamma, | |||
| ) | |||
| super().__init__(model, optimizer, scheduler, device) | |||
| self.triples_factory = triples_factory | |||
| self.margin = margin | |||
| self.n_negative = n_negative | |||
| self.loss_fn = loss_fn | |||
| self.adv_temp = kwargs.get("adv_temp", 1.0)  # temperature for the adversarial weighting (default of 1.0 is an assumption) | |||
| self.negative_sampler = BernoulliNegativeSampler( | |||
| mapped_triples=mapped_triples.to(device), | |||
| num_entities=model.num_entities, | |||
| num_relations=model.num_relations, | |||
| num_negs_per_pos=n_negative, | |||
| ).to(device) | |||
| self.training_loop = SLCWATrainingLoop( | |||
| model=model.inner_model, | |||
| triples_factory=triples_factory, | |||
| optimizer=optimizer, | |||
| negative_sampler=self.negative_sampler, | |||
| lr_scheduler=scheduler, | |||
| ) | |||
| def create_data_loader( | |||
| self, | |||
| triples_factory: torch.Tensor, | |||
| batch_size: int, | |||
| shuffle: bool = True, | |||
| ) -> DataLoader: | |||
| triples_factory = self.triples_factory if triples_factory is None else triples_factory | |||
| return self.training_loop._create_training_data_loader( | |||
| triples_factory=triples_factory, | |||
| sampler=None, | |||
| batch_size=batch_size, | |||
| shuffle=shuffle, | |||
| drop_last=False, | |||
| ) | |||
| def _loss(self, pos_scores: torch.Tensor, neg_scores: torch.Tensor) -> torch.Tensor: | |||
| """ | |||
| Compute the loss based on the positive and negative scores. | |||
| Args: | |||
| pos_scores (torch.Tensor): Scores for positive triples. | |||
| neg_scores (torch.Tensor): Scores for negative triples. | |||
| Returns: | |||
| torch.Tensor: The computed loss. | |||
| """ | |||
| k = neg_scores.shape[1] | |||
| if self.loss_fn == "margin": | |||
| target = -torch.ones_like(neg_scores) # want s_pos < s_neg | |||
| loss = F.margin_ranking_loss( | |||
| repeat(pos_scores, "b -> b k", k=k), | |||
| neg_scores, | |||
| target, | |||
| margin=self.margin, | |||
| ) | |||
| return loss | |||
| if self.loss_fn == "adversarial": | |||
| # b = pos_scores.size(0) | |||
| # neg_scores = rearrange(neg_scores, "(b k) -> b k", b=b, k=k) | |||
| # importance weights (detach so gradients flow only through log-sigmoid terms) | |||
| w = F.softmax(self.adv_temp * neg_scores, dim=1).detach() | |||
| pos_term = -F.logsigmoid(self.margin - pos_scores) # (B,) | |||
| neg_term = -(w * F.logsigmoid(neg_scores - self.margin)).sum(dim=1) # (B,) | |||
| loss = (pos_term + neg_term).mean() | |||
| return loss | |||
| msg = f"Unknown loss_fn: {self.loss_fn!r}" | |||
| raise ValueError(msg) | |||
| def train_step(self, batch: SLCWABatch, step: int) -> float: | |||
| """ | |||
| Perform a training step on the given batch of data. | |||
| Args: | |||
| batch (torch.Tensor): A tensor containing a batch of triples. | |||
| Returns: | |||
| float: The computed loss for the batch. | |||
| """ | |||
| continue_training = step > 0 | |||
| loss = self.training_loop.train( | |||
| triples_factory=self.triples_factory, | |||
| num_epochs=step + 1, | |||
| batch_size=self.training_loop._get_batch_size(batch), | |||
| continue_training=continue_training, | |||
| use_tqdm=False, | |||
| use_tqdm_batch=False, | |||
| label_smoothing=0.0, # Assuming no label smoothing for simplicity | |||
| # num_workers=3, | |||
| )[-1] | |||
| return loss | |||
| def eval_step(self, batch: torch.Tensor) -> torch.Tensor: | |||
| """Return rank tensor for this batch (needed for MRR / Hits@K).""" | |||
| self.model.eval() | |||
| with torch.no_grad(): | |||
| return simple_ranking(self.model, batch.to(self.device)) | |||
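| # Toy numeric sketch of the two loss variants in _loss above, on hand-made distance | |||
| # scores (2 positives, 3 negatives each); no model or sampler involved. | |||
| if __name__ == "__main__": | |||
|     pos = torch.tensor([0.5, 1.2]) | |||
|     neg = torch.tensor([[1.0, 2.0, 0.4], [1.5, 0.9, 2.2]]) | |||
|     margin = 1.0 | |||
|     # margin ranking: mean of max(0, pos - neg + margin) | |||
|     mr = F.relu(repeat(pos, "b -> b k", k=neg.shape[1]) - neg + margin).mean() | |||
|     # self-adversarial: softmax-weighted log-sigmoid terms (temperature 1.0) | |||
|     w = F.softmax(neg, dim=1) | |||
|     adv = (-F.logsigmoid(margin - pos) - (w * F.logsigmoid(neg - margin)).sum(dim=1)).mean() | |||
|     print(f"margin loss = {mr.item():.4f}, adversarial loss = {adv.item():.4f}") | |||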
| @@ -0,0 +1,121 @@ | |||
| import torch | |||
| from pykeen.losses import MarginRankingLoss | |||
| from pykeen.sampling import BernoulliNegativeSampler | |||
| from pykeen.training import SLCWATrainingLoop | |||
| from pykeen.triples.instances import SLCWABatch | |||
| from torch.optim import Adam | |||
| from metrics.ranking import simple_ranking | |||
| from models.translation.trans_h import TransH | |||
| from tools.params import TrainingParams | |||
| from tools.train import StepMinLR | |||
| from training.model_trainers.model_trainer_base import ModelTrainerBase | |||
| class TransHTrainer(ModelTrainerBase): | |||
| def __init__( | |||
| self, | |||
| model: TransH, | |||
| training_params: TrainingParams, | |||
| triples_factory: torch.Tensor, | |||
| mapped_triples: torch.Tensor, | |||
| device: torch.device, | |||
| margin: float = 1.0, | |||
| n_negative: int = 1, | |||
| regul_rate: float = 0.0, | |||
| loss_fn: str = "margin", | |||
| **kwargs, | |||
| ) -> None: | |||
| optimizer = Adam( | |||
| model.inner_model.parameters(), | |||
| lr=training_params.lr, | |||
| weight_decay=training_params.weight_decay, | |||
| ) | |||
| scheduler = StepMinLR( | |||
| optimizer, | |||
| step_size=training_params.lr_step, | |||
| gamma=training_params.gamma, | |||
| min_lr=training_params.min_lr, | |||
| ) | |||
| super().__init__(model, optimizer, scheduler, device) | |||
| self.triples_factory = triples_factory | |||
| self.margin = margin | |||
| self.n_negative = n_negative | |||
| self.regul_rate = regul_rate | |||
| self.loss_fn = loss_fn | |||
| self.negative_sampler = BernoulliNegativeSampler( | |||
| mapped_triples=mapped_triples.to(device), | |||
| num_entities=model.num_entities, | |||
| num_relations=model.num_relations, | |||
| num_negs_per_pos=n_negative, | |||
| ).to(device) | |||
| self.training_loop = SLCWATrainingLoop( | |||
| model=model.inner_model, | |||
| triples_factory=triples_factory, | |||
| optimizer=optimizer, | |||
| negative_sampler=self.negative_sampler, | |||
| lr_scheduler=scheduler, | |||
| ) | |||
| def _loss(self) -> torch.Tensor: | |||
| return self.model.loss | |||
| def create_data_loader( | |||
| self, | |||
| triples_factory: torch.Tensor, | |||
| batch_size: int, | |||
| shuffle: bool = True, | |||
| ) -> torch.utils.data.DataLoader: | |||
| triples_factory = self.triples_factory if triples_factory is None else triples_factory | |||
| return self.training_loop._create_training_data_loader( | |||
| triples_factory=triples_factory, | |||
| sampler=None, | |||
| batch_size=batch_size, | |||
| shuffle=shuffle, | |||
| drop_last=False, | |||
| ) | |||
| @property | |||
| def loss(self) -> MarginRankingLoss: | |||
| return MarginRankingLoss(margin=self.margin, reduction="mean") | |||
| def train_step( | |||
| self, | |||
| batch: SLCWABatch, | |||
| step: int, | |||
| ) -> float: | |||
| """ | |||
| Perform a training step on the given batch of data. | |||
| Args: | |||
| batch (torch.Tensor): A tensor containing a batch of triples. | |||
| Returns: | |||
| float: The computed loss for the batch. | |||
| """ | |||
| continue_training = step > 0 | |||
| loss = self.training_loop.train( | |||
| triples_factory=self.triples_factory, | |||
| num_epochs=step + 1, | |||
| batch_size=self.training_loop._get_batch_size(batch), | |||
| continue_training=continue_training, | |||
| use_tqdm=False, | |||
| use_tqdm_batch=False, | |||
| label_smoothing=0.0, # Assuming no label smoothing for simplicity | |||
| )[-1] | |||
| return loss | |||
| def eval_step(self, batch: torch.Tensor) -> torch.Tensor: | |||
| """Return rank tensor for this batch (needed for MRR / Hits@K).""" | |||
| self.model.eval() | |||
| with torch.no_grad(): | |||
| return simple_ranking(self.model, batch.to(self.device)) | |||
| @@ -0,0 +1,116 @@ | |||
| import torch | |||
| from pykeen.sampling import BernoulliNegativeSampler | |||
| from pykeen.training import SLCWATrainingLoop | |||
| from pykeen.triples.instances import SLCWABatch | |||
| from torch.optim import Adam | |||
| from metrics.ranking import simple_ranking | |||
| from models.translation.trans_r import TransR | |||
| from tools.params import TrainingParams | |||
| from tools.train import StepMinLR | |||
| from training.model_trainers.model_trainer_base import ModelTrainerBase | |||
| class TransRTrainer(ModelTrainerBase): | |||
| def __init__( | |||
| self, | |||
| model: TransR, | |||
| training_params: TrainingParams, | |||
| triples_factory: torch.Tensor, | |||
| mapped_triples: torch.Tensor, | |||
| device: torch.device, | |||
| margin: float = 1.0, | |||
| n_negative: int = 1, | |||
| regul_rate: float = 0.0, | |||
| loss_fn: str = "margin", | |||
| **kwargs, | |||
| ) -> None: | |||
| optimizer = Adam( | |||
| model.inner_model.parameters(), | |||
| lr=training_params.lr, | |||
| weight_decay=training_params.weight_decay, | |||
| ) | |||
| scheduler = StepMinLR( | |||
| optimizer, | |||
| step_size=training_params.lr_step, | |||
| gamma=training_params.gamma, | |||
| min_lr=training_params.min_lr, | |||
| ) | |||
| super().__init__(model, optimizer, scheduler, device) | |||
| self.triples_factory = triples_factory | |||
| self.margin = margin | |||
| self.n_negative = n_negative | |||
| self.regul_rate = regul_rate | |||
| self.loss_fn = loss_fn | |||
| self.negative_sampler = BernoulliNegativeSampler( | |||
| mapped_triples=mapped_triples.to(device), | |||
| num_entities=model.num_entities, | |||
| num_relations=model.num_relations, | |||
| num_negs_per_pos=n_negative, | |||
| ).to(device) | |||
| self.training_loop = SLCWATrainingLoop( | |||
| model=model.inner_model, | |||
| triples_factory=triples_factory, | |||
| optimizer=optimizer, | |||
| negative_sampler=self.negative_sampler, | |||
| lr_scheduler=scheduler, | |||
| ) | |||
| def create_data_loader( | |||
| self, | |||
| triples_factory: torch.Tensor, | |||
| batch_size: int, | |||
| shuffle: bool = True, | |||
| ) -> torch.utils.data.DataLoader: | |||
| triples_factory = self.triples_factory if triples_factory is None else triples_factory | |||
| return self.training_loop._create_training_data_loader( | |||
| triples_factory=triples_factory, | |||
| sampler=None, | |||
| batch_size=batch_size, | |||
| shuffle=shuffle, | |||
| drop_last=False, | |||
| ) | |||
| def _loss(self) -> torch.Tensor: | |||
| return self.model.loss | |||
| def train_step( | |||
| self, | |||
| batch: SLCWABatch, | |||
| step: int, | |||
| ) -> float: | |||
| """ | |||
| Perform a training step on the given batch of data. | |||
| Args: | |||
| batch (torch.Tensor): A tensor containing a batch of triples. | |||
| Returns: | |||
| float: The computed loss for the batch. | |||
| """ | |||
| continue_training = step > 0 | |||
| loss = self.training_loop.train( | |||
| triples_factory=self.triples_factory, | |||
| num_epochs=step + 1, | |||
| batch_size=self.training_loop._get_batch_size(batch), | |||
| continue_training=continue_training, | |||
| use_tqdm=False, | |||
| use_tqdm_batch=False, | |||
| label_smoothing=0.0, # Assuming no label smoothing for simplicity | |||
| )[-1] | |||
| return loss | |||
| def eval_step(self, batch: torch.Tensor) -> torch.Tensor: | |||
| """Return rank tensor for this batch (needed for MRR / Hits@K).""" | |||
| self.model.eval() | |||
| with torch.no_grad(): | |||
| return simple_ranking(self.model, batch.to(self.device)) | |||
| @@ -0,0 +1,214 @@ | |||
| from __future__ import annotations | |||
| from pathlib import Path | |||
| from typing import TYPE_CHECKING, Any | |||
| import torch | |||
| from pykeen.evaluation import RankBasedEvaluator | |||
| from pykeen.metrics.ranking import HitsAtK, InverseHarmonicMeanRank | |||
| from torch.utils.data import DataLoader | |||
| from torch.utils.tensorboard import SummaryWriter | |||
| from tools import ( | |||
| CheckpointManager, | |||
| DeterministicRandomSampler, | |||
| TensorBoardHandler, | |||
| get_pretty_logger, | |||
| set_seed, | |||
| ) | |||
| if TYPE_CHECKING: | |||
| from collections.abc import Iterator | |||
| from data.kg_dataset import KGDataset | |||
| from tools import CommonParams, TrainingParams | |||
| from .model_trainers.model_trainer_base import ModelTrainerBase | |||
| logger = get_pretty_logger(__name__) | |||
| cpu_device = torch.device("cpu") | |||
| class Trainer: | |||
| def __init__( | |||
| self, | |||
| train_dataset: KGDataset, | |||
| val_dataset: KGDataset | None, | |||
| model_trainer: ModelTrainerBase, | |||
| common_params: CommonParams, | |||
| training_params: TrainingParams, | |||
| device: torch.device = cpu_device, | |||
| ) -> None: | |||
| set_seed(training_params.seed) | |||
| self.train_dataset = train_dataset | |||
| self.val_dataset = val_dataset if val_dataset is not None else train_dataset | |||
| self.model_trainer = model_trainer | |||
| self.common_params = common_params | |||
| self.training_params = training_params | |||
| self.device = device | |||
| self.train_loader = self.model_trainer.create_data_loader( | |||
| triples_factory=self.train_dataset.triples_factory, | |||
| batch_size=training_params.batch_size, | |||
| shuffle=True, | |||
| ) | |||
| self.train_iterator = self._data_iterator(self.train_loader) | |||
| seed = 1234 | |||
| val_sampler = DeterministicRandomSampler( | |||
| len(self.val_dataset), | |||
| sample_size=self.training_params.validation_sample_size, | |||
| seed=seed, | |||
| ) | |||
| self.valid_loader = DataLoader( | |||
| self.val_dataset, | |||
| batch_size=training_params.validation_batch_size, | |||
| shuffle=False, | |||
| num_workers=training_params.num_workers, | |||
| sampler=val_sampler, | |||
| ) | |||
| self.valid_iterator = self._data_iterator(self.valid_loader) | |||
| self.step = 0 | |||
| self.log_every = common_params.log_every | |||
| self.log_console_every = common_params.log_console_every | |||
| self.save_dpath = common_params.save_dpath | |||
| self.save_every = common_params.save_every | |||
| self.checkpoint_manager = CheckpointManager( | |||
| root_directory=self.save_dpath, | |||
| run_name=common_params.run_name, | |||
| load_only=common_params.evaluate_only, | |||
| ) | |||
| self.ckpt_dpath = self.checkpoint_manager.checkpoint_directory | |||
| if not common_params.evaluate_only: | |||
| logger.info(f"Saving checkpoints to {self.ckpt_dpath}") | |||
| if self.common_params.load_path: | |||
| self.load() | |||
| self.writer = None | |||
| if not common_params.evaluate_only: | |||
| run_name = self.checkpoint_manager.run_name | |||
| self.writer = SummaryWriter(log_dir=f"{common_params.log_dir}/{run_name}") | |||
| logger.addHandler(TensorBoardHandler(self.writer)) | |||
| def log(self, tag: str, log_value: Any, console_msg: str | None = None) -> None: | |||
| """ | |||
| Log a message to the console and TensorBoard. | |||
| Args: | |||
| tag (str): TensorBoard tag to log the value under. | |||
| log_value (Any): Scalar value to record. | |||
| console_msg (str | None): Optional message to also print to the console. | |||
| """ | |||
| if console_msg is not None: | |||
| logger.info(f"{console_msg}") | |||
| if self.writer is not None: | |||
| self.writer.add_scalar(tag, log_value, self.step) | |||
| def save(self) -> None: | |||
| save_dpath = Path(self.ckpt_dpath) | |||
| save_dpath.mkdir(parents=True, exist_ok=True) | |||
| self.model_trainer.save(save_dpath, self.step) | |||
| def load(self) -> None: | |||
| load_path = self.checkpoint_manager.get_model_fpath(self.common_params.load_path) | |||
| load_fpath = Path(load_path) | |||
| if load_fpath.exists(): | |||
| self.step = self.model_trainer.load(load_fpath) | |||
| logger.info(f"Model loaded from {load_fpath}.") | |||
| else: | |||
| msg = f"Load path {load_fpath} does not exist." | |||
| raise FileNotFoundError(msg) | |||
| def _data_iterator(self, loader: DataLoader) -> Iterator: | |||
| """Yield batches of positive and negative triples.""" | |||
| while True: | |||
| yield from loader | |||
| def train(self) -> None: | |||
| """Run `n_steps` optimisation steps; return mean loss.""" | |||
| data_batch = next(self.train_iterator) | |||
| loss = self.model_trainer.train_step(data_batch, self.step) | |||
| if self.step % self.log_console_every == 0: | |||
| logger.info(f"[train] step {self.step:6d} | loss = {loss:.4f}") | |||
| if self.step % self.log_every == 0: | |||
| self.log("train/loss", loss) | |||
| self.step += 1 | |||
| def evaluate(self) -> dict[str, float]: | |||
| ks = [1, 3, 10] | |||
| metrics = [InverseHarmonicMeanRank()] + [HitsAtK(k) for k in ks] | |||
| evaluator = RankBasedEvaluator( | |||
| metrics=metrics, | |||
| filtered=True, | |||
| batch_size=self.training_params.validation_batch_size, | |||
| ) | |||
| result = evaluator.evaluate( | |||
| model=self.model_trainer.model, | |||
| mapped_triples=self.val_dataset.triples, | |||
| additional_filter_triples=[ | |||
| self.train_dataset.triples, | |||
| self.val_dataset.triples, | |||
| ], | |||
| batch_size=self.training_params.validation_batch_size, | |||
| ) | |||
| mrr = result.get_metric("both.realistic.inverse_harmonic_mean_rank") | |||
| hits_at = {f"hits_at_{k}": result.get_metric(f"both.realistic.hits_at_{k}") for k in ks} | |||
| metrics = { | |||
| "mrr": mrr, | |||
| **hits_at, | |||
| } | |||
| for k, v in metrics.items(): | |||
| self.log(f"valid/{k}", v, console_msg=f"[valid] {k:6s} = {v:.4f}") | |||
| # ───────────────────────── # | |||
| # master schedule | |||
| # ───────────────────────── # | |||
| def run(self) -> None: | |||
| """ | |||
| Alternate training and evaluation until `total_steps` | |||
| optimisation steps have been executed. | |||
| """ | |||
| total_steps = self.training_params.num_train_steps | |||
| eval_every = self.training_params.eval_every | |||
| while self.step < total_steps: | |||
| self.train() | |||
| if self.step % eval_every == 0: | |||
| logger.info(f"Evaluating at step {self.step}...") | |||
| self.evaluate() | |||
| if self.step % self.save_every == 0: | |||
| logger.info(f"Saving model at step {self.step}...") | |||
| self.save() | |||
| logger.info("Training complete.") | |||
| def reset(self) -> None: | |||
| """ | |||
| Reset the trainer state, including the model trainer and data iterators. | |||
| """ | |||
| self.model_trainer.reset()  # assumes the concrete model trainer implements reset() | |||
| self.train_iterator = self._data_iterator(self.train_loader) | |||
| self.valid_iterator = self._data_iterator(self.valid_loader) | |||
| self.step = 0 | |||
| logger.info("Trainer state has been reset.") | |||