| # KGEvaluation |
| from __future__ import annotations | |||||
| import argparse | |||||
| import logging | |||||
| import torch | |||||
| from data import YAGO310Dataset | |||||
| # from metrics.crec_modifier import tune_crec_parallel_edits | |||||
| from data.kg_dataset import KGDataset | |||||
| from metrics.crec_radius_sample import find_sample | |||||
| from metrics.wlcrec import WLCREC | |||||
| from tools import get_pretty_logger | |||||
| logging = get_pretty_logger(__name__) | |||||
| # -------------------------------------------------------------------- | |||||
| # CLI | |||||
| # -------------------------------------------------------------------- | |||||
| def main(): | |||||
| p = argparse.ArgumentParser() | |||||
| p.add_argument("--depth", type=int, default=5) | |||||
| p.add_argument("--lower", type=float, default=0.25) | |||||
| p.add_argument("--upper", type=float, default=0.26) | |||||
| p.add_argument("--seed", type=int, default=0) | |||||
| p.add_argument("--save", type=str, default="fb15k_crec_bi.pt") | |||||
| args = p.parse_args() | |||||
| # ds = FB15KDataset() | |||||
| ds = YAGO310Dataset() | |||||
| triples = ds.triples | |||||
| n_ent, n_rel = ds.num_entities, ds.num_relations | |||||
| logging.info("YAGO3-10 loaded |V|=%d |R|=%d |T|=%d", n_ent, n_rel, len(triples)) | |||||
| # wlcrec = WLCREC(ds) | |||||
| # c = wlcrec.compute(5)[-2] # C_NWLEC | |||||
| # colours = wlcrec.wl_colours(args.depth)[-1] | |||||
| # tuned = tune_crec( | |||||
| # wlcrec, n_ent, n_rel, depth=args.depth, lo=args.lower, hi=args.upper, seed=args.seed | |||||
| # ) | |||||
| # tuned = tune_crec_parallel_edits( | |||||
| # triples, | |||||
| # n_ent, | |||||
| # n_rel, | |||||
| # depth=args.depth, | |||||
| # lo=args.lower, | |||||
| # hi=args.upper, | |||||
| # seed=args.seed, | |||||
| # ) | |||||
| # tuned = fast_tune_crec( | |||||
| # triples.tolist(), | |||||
| # colours, | |||||
| # n_ent, | |||||
| # n_rel, | |||||
| # depth=args.depth, | |||||
| # c=c, | |||||
| # lo=args.lower, | |||||
| # hi=args.upper, | |||||
| # max_iters=1000, | |||||
| # max_workers=10, # Adjust number of workers as needed | |||||
| # ) | |||||
| tuned = find_sample( | |||||
| triples.tolist(), | |||||
| n_ent, | |||||
| n_rel, | |||||
| depth=args.depth, | |||||
| ratio=0.1, # Adjust ratio as needed | |||||
| radius=2, # Adjust radius as needed | |||||
| lower=args.lower, | |||||
| upper=args.upper, | |||||
| max_tries=1000, | |||||
| seed=58, | |||||
| ) | |||||
| tuned_ds = KGDataset( | |||||
| triples_factory=None, | |||||
| triples=torch.tensor(tuned, dtype=torch.long), | |||||
| num_entities=n_ent, | |||||
| num_relations=n_rel, | |||||
| split="train", | |||||
| ) | |||||
| tuned_wlcrec = WLCREC(tuned_ds) | |||||
| entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = tuned_wlcrec.compute(5) | |||||
| print( | |||||
| f"\nTuned CREC results (H={5})", | |||||
| f"\n • avg WLEC : {entropy:.6f} nats", | |||||
| f"\n • C_ratio : {c_ratio:.6f}", | |||||
| f"\n • C_NWLEC : {c_nwlec:.6f}", | |||||
| f"\n • H_cond(R|S_H) : {h_cond:.6f} nats", | |||||
| f"\n • D_ratio : {d_ratio:.6f}", | |||||
| f"\n • D_NWLEC : {d_nwlec:.6f}", | |||||
| ) | |||||
| torch.save(torch.tensor(tuned, dtype=torch.long), args.save) | |||||
| logging.info("Saved tuned triples → %s", args.save) | |||||
| if __name__ == "__main__": | |||||
| main() |
| _target_: tools.params.CommonParams | |||||
| model_name: "default_model" | |||||
| project_name: "my_project" | |||||
| run_name: ... | |||||
| # Where to write all outputs (checkpoints / tensorboard logs / etc.) | |||||
| save_dpath: "./checkpoints" | |||||
| save_every: 200 | |||||
| # If you want to resume from a checkpoint, set load_dpath to a string path; | |||||
| # otherwise leave it null or empty. | |||||
| load_path: null | |||||
| # Default log directory | |||||
| log_dir: "logs" | |||||
| log_every: 10 | |||||
| log_console_every: 10 | |||||
| evaluate_only: false |
| defaults: | |||||
| - common: common | |||||
| - model: model | |||||
| - data: data | |||||
| - training: training | |||||
| hydra: | |||||
| run: | |||||
| dir: "./outputs" |
| _target_: data.wn18rr_dataset.WN18RRDataset | |||||
| split: train |
| train: | |||||
| _target_: data.fb15k.FB15KDataset | |||||
| split: train | |||||
| valid: | |||||
| _target_: data.fb15k.FB15KDataset | |||||
| split: valid |
| train: | |||||
| _target_: data.wn18.WN18Dataset | |||||
| split: train | |||||
| valid: | |||||
| _target_: data.wn18.WN18Dataset | |||||
| split: valid |
| train: | |||||
| _target_: data.wn18.WN18RRDataset | |||||
| split: train | |||||
| valid: | |||||
| _target_: data.wn18.WN18RRDataset | |||||
| split: valid |
| train: | |||||
| _target_: data.yago3_10.YAGO310Dataset | |||||
| split: train | |||||
| valid: | |||||
| _target_: data.yago3_10.YAGO310Dataset | |||||
| split: valid | |||||
| num_entities: 100 | |||||
| num_relations: 100 | |||||
| dim: 100 |
| defaults: | |||||
| - model | |||||
| - _self_ | |||||
| _target_: models.translation.trans_e.TransE | |||||
| dim: 200 | |||||
| sf_norm: 1 | |||||
| p_norm: true |
| defaults: | |||||
| - model | |||||
| - _self_ | |||||
| _target_: models.translation.trans_h.TransH | |||||
| dim: 200 | |||||
| sf_norm: 1 | |||||
| p_norm: true | |||||
| p_norm_value: 2 | |||||
| margin: 1.0 |
| defaults: | |||||
| - model | |||||
| - _self_ | |||||
| _target_: models.translation.trans_r.TransR | |||||
| entity_dim: 200 | |||||
| relation_dim: 30 | |||||
| sf_norm: 1 | |||||
| p_norm: true | |||||
| p_norm_value: 2 | |||||
| margin: 1.0 |
| _target_: tools.params.TrainingParams | |||||
| lr: 0.005 | |||||
| min_lr: 0.0001 | |||||
| weight_decay: 0.0 | |||||
| t0: 50 | |||||
| lr_step: 10 | |||||
| gamma: 0.95 | |||||
| batch_size: 8192 | |||||
| num_workers: 3 | |||||
| seed: 42 | |||||
| num_train_steps: 10000 | |||||
| eval_every: 100 | |||||
| validation_sample_size: 1000 | |||||
| validation_batch_size: 64 | |||||
| model_trainer: | |||||
| _target_: training.model_trainers.model_trainer_base.ModelTrainerBase |
| defaults: | |||||
| - training | |||||
| - _self_ | |||||
| model_trainer: | |||||
| _target_: training.model_trainers.translation.trans_e_trainer.TransETrainer | |||||
| margin: 1.0 | |||||
| n_negative: 5 | |||||
| regul_rate: 0.0 | |||||
| loss_fn: "margin" |
| defaults: | |||||
| - training | |||||
| - _self_ | |||||
| model_trainer: | |||||
| _target_: training.model_trainers.translation.trans_r_trainer.TransRTrainer | |||||
| margin: 1.0 | |||||
| n_negative: 10 | |||||
| regul_rate: 0.0 | |||||
| loss_fn: "margin" |
| defaults: | |||||
| - config | |||||
| - override common: common | |||||
| - override data: fb15k | |||||
| - override model: trans_e | |||||
| - override training: trans_e_trainer | |||||
| - _self_ | |||||
| training: | |||||
| model_trainer: | |||||
| loss_fn: "margin" | |||||
| common: | |||||
| run_name: "TransE_FB15K" | |||||
| evaluate_only: true | |||||
| load_path: (2, 1800) |
| defaults: | |||||
| - config | |||||
| - override common: common | |||||
| - override data: wn18 | |||||
| - override model: trans_e | |||||
| - override training: trans_e_trainer | |||||
| - _self_ | |||||
| training: | |||||
| model_trainer: | |||||
| loss_fn: "margin" | |||||
| common: | |||||
| run_name: "TransE_WN18" | |||||
| evaluate_only: true | |||||
| load_path: (82, 3000) |
| defaults: | |||||
| - config | |||||
| - override common: common | |||||
| - override data: wn18rr | |||||
| - override model: trans_e | |||||
| - override training: trans_e_trainer | |||||
| - _self_ | |||||
| training: | |||||
| model_trainer: | |||||
| loss_fn: "margin" | |||||
| common: | |||||
| run_name: "TransE_WN18rr" | |||||
| evaluate_only: true | |||||
| load_path: (2, 6000) |
| defaults: | |||||
| - config | |||||
| - override common: common | |||||
| - override data: yago3_10 | |||||
| - override model: trans_e | |||||
| - override training: trans_e_trainer | |||||
| - _self_ | |||||
| training: | |||||
| model_trainer: | |||||
| loss_fn: "margin" | |||||
| common: | |||||
| run_name: "TransE_YAGO310" | |||||
| evaluate_only: true | |||||
| load_path: (1, 2600) |
| defaults: | |||||
| - config | |||||
| - override common: common | |||||
| - override data: wn18 | |||||
| - override model: trans_r | |||||
| - override training: trans_r_trainer | |||||
| - _self_ | |||||
| training: | |||||
| model_trainer: | |||||
| loss_fn: "margin" | |||||
| common: | |||||
| run_name: "TransR_WN18" | |||||
| # evaluate_only: true | |||||
| # load_path: (1, 1000) |
| from .fb15k import FB15KDataset | |||||
| from .wn18 import WN18Dataset, WN18RRDataset | |||||
| from .yago3_10 import YAGO310Dataset | |||||
| from .hationet import HetionetDataset | |||||
| from .open_bio_link import OpenBioLinkDataset | |||||
| from .openke_wiki import OpenKEWikiDataset |
| from typing import List, Tuple | |||||
| import torch | |||||
| from pykeen.datasets import FB15k | |||||
| from .kg_dataset import KGDataset | |||||
| class FB15KDataset(KGDataset): | |||||
| """A wrapper around PyKeen's FB15k dataset producing KGDataset instances.""" | |||||
| def __init__(self, split: str = "train") -> None: | |||||
| dataset = FB15k() | |||||
| # if split == "train": | |||||
| if True: | |||||
| tf = dataset.training | |||||
| elif split in ("valid", "val", "validation"): | |||||
| tf = dataset.validation | |||||
| elif split in ("test", "testing"): | |||||
| tf = dataset.testing | |||||
| else: | |||||
| msg = f"Unrecognized split '{split}'" | |||||
| raise ValueError(msg) | |||||
| custom_triples = torch.load( | |||||
| "/home/n_kazemi/projects/KGEvaluation/CustomDataset/fb15k/fb15k_crec_bi_1.pt", | |||||
| map_location=torch.device("cpu"), | |||||
| ).to(torch.long) | |||||
| # .tolist() | |||||
| tf = tf.clone_and_exchange_triples( | |||||
| custom_triples, | |||||
| ) | |||||
| triples_list: List[Tuple[int, int, int]] = [ | |||||
| (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||||
| ] | |||||
| super().__init__( | |||||
| triples_factory=tf, | |||||
| triples=triples_list, | |||||
| num_entities=dataset.num_entities, | |||||
| num_relations=dataset.num_relations, | |||||
| split=split, | |||||
| ) | |||||
| def __getitem__(self, idx: int) -> torch.Tensor: | |||||
| return torch.tensor(super().__getitem__(idx), dtype=torch.long) |
| from typing import List, Tuple | |||||
| import torch | |||||
| from pykeen.datasets import Hetionet | |||||
| from .kg_dataset import KGDataset # Importing KGDataset from the same module | |||||
| class HetionetDataset(KGDataset): | |||||
| """A wrapper around PyKeen's Hetionet dataset that produces KGDataset instances.""" | |||||
| def __init__(self, split: str = "train") -> None: | |||||
| """ | |||||
| Initialize the WN18RR wrapper. | |||||
| :param split: One of "train", "valid", or "test" indicating which split to load. | |||||
| """ | |||||
| # Load the PyKeen WN18 dataset | |||||
| dataset = Hetionet() | |||||
| # Select the appropriate TriplesFactory based on the requested split | |||||
| if split == "train": | |||||
| tf = dataset.training | |||||
| elif split in ("valid", "val", "validation"): | |||||
| tf = dataset.validation | |||||
| elif split in ("test", "testing"): | |||||
| tf = dataset.testing | |||||
| else: | |||||
| msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||||
| raise ValueError( | |||||
| msg, | |||||
| ) | |||||
| # `tf.mapped_triples` is a NumPy array of shape (n_triples, 3) with dtype=int64 | |||||
| # Convert each row into a Python tuple of ints | |||||
| triples_list: List[Tuple[int, int, int]] = [ | |||||
| (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||||
| ] | |||||
| # Get the total number of entities and relations in the entire dataset | |||||
| num_entities = dataset.num_entities | |||||
| num_relations = dataset.num_relations | |||||
| # Initialize the base KGDataset with the selected triples | |||||
| super().__init__( | |||||
| triples_factory=tf, | |||||
| triples=triples_list, | |||||
| num_entities=num_entities, | |||||
| num_relations=num_relations, | |||||
| split=split, | |||||
| ) | |||||
| def __getitem__(self, idx: int) -> torch.Tensor: | |||||
| data = super().__getitem__(idx) | |||||
| # Convert the tuple to a torch tensor | |||||
| return torch.tensor(data, dtype=torch.long) |
| """PyTorch Dataset wrapping indexed triples (h, r, t).""" | |||||
| from pathlib import Path | |||||
| from typing import Dict, List, Tuple | |||||
| import torch | |||||
| from pykeen.typing import MappedTriples | |||||
| from torch.utils.data import Dataset | |||||
| class KGDataset(Dataset): | |||||
| """PyTorch Dataset wrapping indexed triples (h, r, t).""" | |||||
| def __init__( | |||||
| self, | |||||
| triples_factory: MappedTriples, | |||||
| triples: List[Tuple[int, int, int]], | |||||
| num_entities: int, | |||||
| num_relations: int, | |||||
| split: str = "train", | |||||
| ) -> None: | |||||
| super().__init__() | |||||
| self.triples_factory = triples_factory | |||||
| self.triples = torch.tensor(triples, dtype=torch.long) | |||||
| self.num_entities = num_entities | |||||
| self.num_relations = num_relations | |||||
| self.split = split | |||||
| def __len__(self) -> int: | |||||
| """Return the number of triples in the dataset.""" | |||||
| return len(self.triples) | |||||
| def __getitem__(self, idx: int) -> Tuple[int, int, int]: | |||||
| """Return the indexed triple at the given index.""" | |||||
| return self.triples[idx] | |||||
| def entities(self) -> torch.Tensor: | |||||
| """Return a list of all entities in the dataset.""" | |||||
| return torch.arange(self.num_entities) | |||||
| def relations(self) -> torch.Tensor: | |||||
| """Return a list of all relations in the dataset.""" | |||||
| return torch.arange(self.num_relations) | |||||
| def __repr__(self) -> str: | |||||
| return ( | |||||
| f"{self.__class__.__name__}(split={self.split!r}, " | |||||
| f"num_triples={len(self.triples)}, " | |||||
| f"num_entities={self.num_entities}, " | |||||
| f"num_relations={self.num_relations})" | |||||
| ) | |||||
| # ---------------------------- Helper functions -------------------------------- | |||||
| def _map_to_ids( | |||||
| triples: List[Tuple[str, str, str]], | |||||
| ) -> Tuple[List[Tuple[int, int, int]], Dict[str, int], Dict[str, int]]: | |||||
| ent2id: Dict[str, int] = {} | |||||
| rel2id: Dict[str, int] = {} | |||||
| mapped: List[Tuple[int, int, int]] = [] | |||||
| def _id(d: Dict[str, int], k: str) -> int: | |||||
| d.setdefault(k, len(d)) | |||||
| return d[k] | |||||
| for h, r, t in triples: | |||||
| mapped.append((_id(ent2id, h), _id(rel2id, r), _id(ent2id, t))) | |||||
| return mapped, ent2id, rel2id | |||||
| def load_kg_from_files( | |||||
| root: str, splits: tuple[str] = ("train", "valid", "test") | |||||
| ) -> Tuple[Dict[str, KGDataset], Dict[str, int], Dict[str, int]]: | |||||
| root = Path(root) | |||||
| split_raw, all_raw = {}, [] | |||||
| for s in splits: | |||||
| lines = [tuple(line.split("\t")) for line in (root / f"{s}.txt").read_text().splitlines()] | |||||
| split_raw[s] = lines | |||||
| all_raw.extend(lines) | |||||
| mapped, ent2id, rel2id = _map_to_ids(all_raw) | |||||
| datasets, start = {}, 0 | |||||
| for s in splits: | |||||
| end = start + len(split_raw[s]) | |||||
| datasets[s] = KGDataset(mapped[start:end], len(ent2id), len(rel2id), s) | |||||
| start = end | |||||
| return datasets, ent2id, rel2id |
| from typing import List, Tuple | |||||
| import torch | |||||
| from pykeen.datasets import OpenBioLink | |||||
| from .kg_dataset import KGDataset # Importing KGDataset from the same module | |||||
| class OpenBioLinkDataset(KGDataset): | |||||
| """A wrapper around PyKeen's OpenBioLink dataset that produces KGDataset instances.""" | |||||
| def __init__(self, split: str = "train") -> None: | |||||
| """ | |||||
| Initialize the WN18RR wrapper. | |||||
| :param split: One of "train", "valid", or "test" indicating which split to load. | |||||
| """ | |||||
| # Load the PyKeen WN18 dataset | |||||
| dataset = OpenBioLink() | |||||
| # Select the appropriate TriplesFactory based on the requested split | |||||
| if split == "train": | |||||
| tf = dataset.training | |||||
| elif split in ("valid", "val", "validation"): | |||||
| tf = dataset.validation | |||||
| elif split in ("test", "testing"): | |||||
| tf = dataset.testing | |||||
| else: | |||||
| msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||||
| raise ValueError( | |||||
| msg, | |||||
| ) | |||||
| # `tf.mapped_triples` is a NumPy array of shape (n_triples, 3) with dtype=int64 | |||||
| # Convert each row into a Python tuple of ints | |||||
| triples_list: List[Tuple[int, int, int]] = [ | |||||
| (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||||
| ] | |||||
| # Get the total number of entities and relations in the entire dataset | |||||
| num_entities = dataset.num_entities | |||||
| num_relations = dataset.num_relations | |||||
| # Initialize the base KGDataset with the selected triples | |||||
| super().__init__( | |||||
| triples_factory=tf, | |||||
| triples=triples_list, | |||||
| num_entities=num_entities, | |||||
| num_relations=num_relations, | |||||
| split=split, | |||||
| ) | |||||
| def __getitem__(self, idx: int) -> torch.Tensor: | |||||
| data = super().__getitem__(idx) | |||||
| # Convert the tuple to a torch tensor | |||||
| return torch.tensor(data, dtype=torch.long) |
| from typing import List, Tuple | |||||
| import torch | |||||
| from pykeen.datasets import Wikidata5M | |||||
| from .kg_dataset import KGDataset # Importing KGDataset from the same module | |||||
| class OpenKEWikiDataset(KGDataset): | |||||
| """A wrapper around PyKeen's OpenKE-Wiki (WikiN3) dataset that produces KGDataset instances.""" | |||||
| def __init__(self, split: str = "train") -> None: | |||||
| """ | |||||
| Initialize the WN18RR wrapper. | |||||
| :param split: One of "train", "valid", or "test" indicating which split to load. | |||||
| """ | |||||
| # Load the PyKeen WN18 dataset | |||||
| dataset = Wikidata5M() | |||||
| # Select the appropriate TriplesFactory based on the requested split | |||||
| if split == "train": | |||||
| tf = dataset.training | |||||
| elif split in ("valid", "val", "validation"): | |||||
| tf = dataset.validation | |||||
| elif split in ("test", "testing"): | |||||
| tf = dataset.testing | |||||
| else: | |||||
| msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||||
| raise ValueError( | |||||
| msg, | |||||
| ) | |||||
| # `tf.mapped_triples` is a NumPy array of shape (n_triples, 3) with dtype=int64 | |||||
| # Convert each row into a Python tuple of ints | |||||
| triples_list: List[Tuple[int, int, int]] = [ | |||||
| (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||||
| ] | |||||
| # Get the total number of entities and relations in the entire dataset | |||||
| num_entities = dataset.num_entities | |||||
| num_relations = dataset.num_relations | |||||
| # Initialize the base KGDataset with the selected triples | |||||
| super().__init__( | |||||
| triples_factory=tf, | |||||
| triples=triples_list, | |||||
| num_entities=num_entities, | |||||
| num_relations=num_relations, | |||||
| split=split, | |||||
| ) | |||||
| def __getitem__(self, idx: int) -> torch.Tensor: | |||||
| data = super().__getitem__(idx) | |||||
| # Convert the tuple to a torch tensor | |||||
| return torch.tensor(data, dtype=torch.long) |
| from typing import List, Tuple | |||||
| import torch | |||||
| from pykeen.datasets import WN18, WN18RR | |||||
| from .kg_dataset import KGDataset # Importing KGDataset from the same module | |||||
| class WN18Dataset(KGDataset): | |||||
| """A wrapper around PyKeen's WN18 dataset that produces KGDataset instances.""" | |||||
| def __init__(self, split: str = "train") -> None: | |||||
| """ | |||||
| Initialize the WN18RR wrapper. | |||||
| :param split: One of "train", "valid", or "test" indicating which split to load. | |||||
| """ | |||||
| # Load the PyKeen WN18 dataset | |||||
| dataset = WN18() | |||||
| # Select the appropriate TriplesFactory based on the requested split | |||||
| if split == "train": | |||||
| tf = dataset.training | |||||
| elif split in ("valid", "val", "validation"): | |||||
| tf = dataset.validation | |||||
| elif split in ("test", "testing"): | |||||
| tf = dataset.testing | |||||
| else: | |||||
| msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||||
| raise ValueError( | |||||
| msg, | |||||
| ) | |||||
| # `tf.mapped_triples` is a NumPy array of shape (n_triples, 3) with dtype=int64 | |||||
| # Convert each row into a Python tuple of ints | |||||
| triples_list: List[Tuple[int, int, int]] = [ | |||||
| (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||||
| ] | |||||
| # Get the total number of entities and relations in the entire dataset | |||||
| num_entities = dataset.num_entities | |||||
| num_relations = dataset.num_relations | |||||
| # Initialize the base KGDataset with the selected triples | |||||
| super().__init__( | |||||
| triples_factory=tf, | |||||
| triples=triples_list, | |||||
| num_entities=num_entities, | |||||
| num_relations=num_relations, | |||||
| split=split, | |||||
| ) | |||||
| def __getitem__(self, idx: int) -> torch.Tensor: | |||||
| data = super().__getitem__(idx) | |||||
| # Convert the tuple to a torch tensor | |||||
| return torch.tensor(data, dtype=torch.long) | |||||
| # Assuming KGDataset is defined in the same module or already imported: | |||||
| # from your_module import KGDataset | |||||
| class WN18RRDataset(KGDataset): | |||||
| """A wrapper around PyKeen's WN18RR dataset that produces KGDataset instances.""" | |||||
| def __init__(self, split: str = "train") -> None: | |||||
| """ | |||||
| Initialize the WN18RR wrapper. | |||||
| :param split: One of "train", "valid", or "test" indicating which split to load. | |||||
| """ | |||||
| # Load the PyKeen WN18RR dataset | |||||
| dataset = WN18RR() | |||||
| # Select the appropriate TriplesFactory based on the requested split | |||||
| if split == "train": | |||||
| tf = dataset.training | |||||
| elif split in ("valid", "val", "validation"): | |||||
| tf = dataset.validation | |||||
| elif split in ("test", "testing"): | |||||
| tf = dataset.testing | |||||
| else: | |||||
| msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||||
| raise ValueError( | |||||
| msg, | |||||
| ) | |||||
| # `tf.mapped_triples` is a NumPy array of shape (n_triples, 3) with dtype=int64 | |||||
| # Convert each row into a Python tuple of ints | |||||
| triples_list: List[Tuple[int, int, int]] = [ | |||||
| (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||||
| ] | |||||
| # Get the total number of entities and relations in the entire dataset | |||||
| num_entities = dataset.num_entities | |||||
| num_relations = dataset.num_relations | |||||
| # Initialize the base KGDataset with the selected triples | |||||
| super().__init__( | |||||
| triples_factory=tf, | |||||
| triples=triples_list, | |||||
| num_entities=num_entities, | |||||
| num_relations=num_relations, | |||||
| split=split, | |||||
| ) | |||||
| def __getitem__(self, idx: int) -> torch.Tensor: | |||||
| data = super().__getitem__(idx) | |||||
| # Convert the tuple to a torch tensor | |||||
| return torch.tensor(data, dtype=torch.long) |
| from typing import List, Tuple | |||||
| import torch | |||||
| from pykeen.datasets import PathDataset | |||||
| from .kg_dataset import KGDataset | |||||
| class YAGO310Dataset(KGDataset): | |||||
| """Wrapper around PyKeen's YAGO3-10 dataset.""" | |||||
| def __init__(self, split: str = "train") -> None: | |||||
| # dataset = YAGO310() | |||||
| dataset = PathDataset( | |||||
| training_path="/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10/train.txt", | |||||
| testing_path="/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10/test.txt", | |||||
| validation_path="/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10/valid.txt", | |||||
| ) | |||||
| if split == "train": | |||||
| # if True: | |||||
| tf = dataset.training | |||||
| elif split in ("valid", "val", "validation"): | |||||
| tf = dataset.validation | |||||
| elif split in ("test", "testing"): | |||||
| tf = dataset.testing | |||||
| else: | |||||
| msg = f"Unrecognized split '{split}'" | |||||
| raise ValueError(msg) | |||||
| custom_triples = torch.load( | |||||
| "/home/n_kazemi/projects/KGEvaluation/CustomDataset/yago310/yago310_crec_bi_1.pt", | |||||
| map_location=torch.device("cpu"), | |||||
| ).to(torch.long) | |||||
| # .tolist() | |||||
| # get a random sample of size 5000 | |||||
| # seed = torch.seed() # For reproducibility | |||||
| # torch.manual_seed(seed) | |||||
| # torch.cuda.manual_seed(seed) | |||||
| # if custom_triples.shape[0] > 5000: | |||||
| # custom_triples = custom_triples[torch.randperm(custom_triples.shape[0])[:5000]] | |||||
| tf = tf.clone_and_exchange_triples( | |||||
| custom_triples, | |||||
| ) | |||||
| # triples = tf.mapped_triples | |||||
| # # get a random sample of size 5000 | |||||
| # if triples.shape[0] > 5000: | |||||
| # indices = torch.randperm(triples.shape[0])[:5000] | |||||
| # triples = triples[indices] | |||||
| # tf = tf.clone_and_exchange_triples(triples) | |||||
| triples_list: List[Tuple[int, int, int]] = [ | |||||
| (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||||
| ] | |||||
| super().__init__( | |||||
| triples_factory=tf, | |||||
| triples=triples_list, | |||||
| num_entities=dataset.num_entities, | |||||
| num_relations=dataset.num_relations, | |||||
| split=split, | |||||
| ) | |||||
| def __getitem__(self, idx: int) -> torch.Tensor: | |||||
| return torch.tensor(super().__getitem__(idx), dtype=torch.long) |
| from data import ( | |||||
| FB15KDataset, | |||||
| WN18Dataset, | |||||
| WN18RRDataset, | |||||
| YAGO310Dataset, | |||||
| ) | |||||
| from metrics.wlcrec import WLCREC | |||||
| def main() -> None: | |||||
| datasets = { | |||||
| "YAGO3-10": YAGO310Dataset(split="train"), | |||||
| # "WN18": WN18Dataset(split="train"), | |||||
| # "WN18RR": WN18RRDataset(split="train"), | |||||
| # "FB15K": FB15KDataset(split="train"), | |||||
| # "Hetionet": HetionetDataset(split="train"), | |||||
| # "OpenBioLink": OpenBioLinkDataset(split="train"), | |||||
| # "OpenKEWiki": OpenKEWikiDataset(split="train"), | |||||
| } | |||||
| for name, dataset in datasets.items(): | |||||
| wl_crec = WLCREC(dataset) | |||||
| entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = wl_crec.compute(H=20) | |||||
| print( | |||||
| f"\nDataset: {name}", | |||||
| f"\nResults (H={20})", | |||||
| f"\n • avg WLEC : {entropy:.6f} nats", | |||||
| f"\n • C_ratio : {c_ratio:.6f}", | |||||
| f"\n • C_NWLEC : {c_nwlec:.6f}", | |||||
| f"\n • H_cond(R|S_H) : {h_cond:.6f} nats", | |||||
| f"\n • D_ratio : {d_ratio:.6f}", | |||||
| f"\n • D_NWLEC : {d_nwlec:.6f}", | |||||
| sep="", | |||||
| ) | |||||
| if __name__ == "__main__": | |||||
| main() |
| from copy import deepcopy | |||||
| import hydra | |||||
| import lovely_tensors as lt | |||||
| import torch | |||||
| from hydra.utils import instantiate | |||||
| from omegaconf import DictConfig | |||||
| from metrics.c_swklf import CSWKLF | |||||
| from tools import get_pretty_logger | |||||
| # (Import whatever Trainer or runner you have) | |||||
| from training.trainer import Trainer | |||||
| logger = get_pretty_logger(__name__) | |||||
| lt.monkey_patch() | |||||
| @hydra.main(config_path="configs", config_name="config") | |||||
| def main(cfg: DictConfig) -> None: | |||||
| # Detect CUDA | |||||
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
| common_params = cfg.common | |||||
| training_params = deepcopy(cfg.training) | |||||
| # drop the model_trainer from cfg.trainer | |||||
| del training_params.model_trainer | |||||
| common_params = instantiate(common_params) | |||||
| training_params = instantiate(training_params) | |||||
| # dataset instantiation | |||||
| train_dataset = instantiate(cfg.data.train) | |||||
| val_dataset = instantiate(cfg.data.valid) | |||||
| model = instantiate( | |||||
| cfg.model, | |||||
| num_entities=train_dataset.num_entities, | |||||
| num_relations=train_dataset.num_relations, | |||||
| triples_factory=train_dataset.triples_factory, | |||||
| device=device, | |||||
| ) | |||||
| model = model.to(device) | |||||
| model_trainer = instantiate( | |||||
| cfg.training.model_trainer, | |||||
| model=model, | |||||
| training_params=training_params, | |||||
| triples_factory=train_dataset.triples_factory, | |||||
| mapped_triples=train_dataset.triples, | |||||
| device=model.device, | |||||
| ) | |||||
| # Instantiate the Trainer with the model_trainer | |||||
| trainer = Trainer( | |||||
| train_dataset=train_dataset, | |||||
| val_dataset=val_dataset, | |||||
| model_trainer=model_trainer, | |||||
| common_params=common_params, | |||||
| training_params=training_params, | |||||
| device=model.device, | |||||
| ) | |||||
| if cfg.common.evaluate_only: | |||||
| trainer.evaluate() | |||||
| c_swklf_metric = CSWKLF( | |||||
| dataset=val_dataset, | |||||
| model=model, | |||||
| ) | |||||
| c_swklf_value = c_swklf_metric.compute() | |||||
| logger.info(f"CSWKLF Metric Value: {c_swklf_value}") | |||||
| else: | |||||
| trainer.run() | |||||
| if __name__ == "__main__": | |||||
| main() |
| from __future__ import annotations | |||||
| from abc import ABC, abstractmethod | |||||
| from typing import TYPE_CHECKING, Any, List, Tuple | |||||
| if TYPE_CHECKING: | |||||
| from data.kg_dataset import KGDataset | |||||
| class BaseMetric(ABC): | |||||
| """Base class for metrics that compute Weisfeiler-Lehman Entropy | |||||
| Complexity (WLEC) or similar metrics. | |||||
| """ | |||||
| def __init__(self, dataset: KGDataset) -> None: | |||||
| self.dataset = dataset | |||||
| self.num_entities = dataset.num_entities | |||||
| self.out_adj, self.in_adj = self._build_adjacency_lists() | |||||
| def _build_adjacency_lists( | |||||
| self, | |||||
| ) -> Tuple[List[List[Tuple[int, int]]], List[List[Tuple[int, int]]]]: | |||||
| """Create *in* and *out* adjacency lists from mapped triples. | |||||
| Parameters | |||||
| ---------- | |||||
| triples: | |||||
| Tensor of shape *(m, 3)* containing *(h, r, t)* triples (``dtype=torch.long``). | |||||
| num_entities: | |||||
| Total number of entities. Needed to allocate adjacency containers. | |||||
| Returns | |||||
| ------- | |||||
| out_adj, in_adj | |||||
| Each element ``out_adj[v]`` is a list ``[(r, t), ...]``. Each element | |||||
| ``in_adj[v]`` is a list ``[(r, h), ...]``. | |||||
| """ | |||||
| triples = self.dataset.triples # (m, 3) | |||||
| num_entities = self.num_entities | |||||
| out_adj: List[List[Tuple[int, int]]] = [[] for _ in range(num_entities)] | |||||
| in_adj: List[List[Tuple[int, int]]] = [[] for _ in range(num_entities)] | |||||
| for h, r, t in triples.tolist(): | |||||
| out_adj[h].append((r, t)) | |||||
| in_adj[t].append((r, h)) | |||||
| return out_adj, in_adj | |||||
| @abstractmethod | |||||
| def compute( | |||||
| self, | |||||
| *args, | |||||
| **kwargs, | |||||
| ) -> Any: ... |
| from __future__ import annotations | |||||
| import math | |||||
| from collections import defaultdict | |||||
| from dataclasses import dataclass | |||||
| from typing import TYPE_CHECKING, Callable, Iterable | |||||
| import torch | |||||
| from torch import Tensor | |||||
| from .base_metric import BaseMetric | |||||
| if TYPE_CHECKING: | |||||
| from data.kg_dataset import KGDataset | |||||
| from models.base_model import ERModel | |||||
| @torch.no_grad() | |||||
| def _log_prob_distribution( | |||||
| fn: Callable[[Tensor], Tensor], | |||||
| query: Tensor, | |||||
| batch_size: int, | |||||
| ) -> Iterable[Tensor]: | |||||
| """ | |||||
| Yield *log*-probability batches for an arbitrary distribution helper. | |||||
| Parameters | |||||
| ---------- | |||||
| fn : | |||||
| A bound method of *model* returning **log**-probabilities, e.g. | |||||
| ``model.tail_distribution(..., log=True)``. | |||||
| query : | |||||
| LongTensor of shape (N, 2) containing the query IDs that `fn` expects. | |||||
| batch_size : | |||||
| Mini-batch size. Choose according to available GPU/CPU memory. | |||||
| """ | |||||
| for i in range(0, len(query), batch_size): | |||||
| yield fn(query[i : i + batch_size]) # (B, |Candidates|) | |||||
| def _kl_uniform(log_p: Tensor, true_index_sets: list[list[int]]) -> Tensor: | |||||
| """ | |||||
| KL(q || p) where q is uniform over *true_index_sets*. | |||||
| Parameters | |||||
| ---------- | |||||
| log_p : | |||||
| Tensor of shape (B, N) — log-probabilities from the model. | |||||
| true_index_sets : | |||||
| List (length B). Element *b* contains the list of **indices** | |||||
| that are correct for row *b*. | |||||
| Returns | |||||
| ------- | |||||
| kl : Tensor, shape (B,) — KL divergence per row (natural log). | |||||
| """ | |||||
| rows: list[Tensor] = [] | |||||
| for lp_row, idx in zip(log_p, true_index_sets): | |||||
| k = len(idx) | |||||
| rows.append(math.log(k) - lp_row[idx].mean()) # log k - E_q[log p] | |||||
| return torch.tensor(rows, device=log_p.device) | |||||
| @dataclass(slots=True) | |||||
| class _Accum: | |||||
| """Simple accumulator for ∑ value and ∑ weight.""" | |||||
| tot_v: float = 0.0 | |||||
| tot_w: float = 0.0 | |||||
| def update(self, value: float, weight: float) -> None: | |||||
| self.tot_v += value * weight | |||||
| self.tot_w += weight | |||||
| @property | |||||
| def mean(self) -> float: | |||||
| return self.tot_v / self.tot_w if self.tot_w else float("nan") | |||||
| class CSWKLF(BaseMetric): | |||||
| """ | |||||
| C-SWKLF metric for evaluating knowledge graph embeddings. | |||||
| This metric is used to evaluate the quality of knowledge graph embeddings | |||||
| by computing the C-SWKLF score. | |||||
| """ | |||||
| def __init__(self, dataset: KGDataset, model: ERModel) -> None: | |||||
| super().__init__(dataset) | |||||
| self.model = model | |||||
| @torch.no_grad() | |||||
| def compute( | |||||
| self, | |||||
| alpha: float = 1 / 3, | |||||
| beta: float = 1 / 3, | |||||
| gamma: float = 1 / 3, | |||||
| batch_size: int = 1024, | |||||
| slice_size: int | None = 2048, | |||||
| device: torch.device | str | None = None, | |||||
| ) -> float: | |||||
| """ | |||||
| Compute *Comprehensive Structural-Weighted KL-Fitness*. | |||||
| Parameters | |||||
| ---------- | |||||
| model : | |||||
| A trained PyKEEN ERModel **with** the three distribution helpers. | |||||
| kg : | |||||
| `pykeen.triples.KGInfo` instance providing `mapped_triples`. | |||||
| alpha, beta, gamma : | |||||
| Non-negative weights for the three query types, default 1/3 each. | |||||
| Must sum to *1*. | |||||
| batch_size : | |||||
| Number of queries scored per forward pass. Tune wrt. GPU/CPU RAM. | |||||
| slice_size : | |||||
| Forwarded to `tail_distribution` / `head_distribution` / | |||||
| `relation_distribution`. `None` disables slicing. | |||||
| device : | |||||
| Where to do the computation. `None` ⇒ first param's device. | |||||
| Returns | |||||
| ------- | |||||
| score : float in (0, 1] | |||||
| Higher = model closer to empirical distributions across *all* directions. | |||||
| """ | |||||
| # --------------------------------------------------------------------- # | |||||
| # Preparation # | |||||
| # --------------------------------------------------------------------- # | |||||
| assert abs(alpha + beta + gamma - 1.0) < 1e-9, "α+β+γ must be 1." | |||||
| model_device = next(self.model.parameters()).device | |||||
| if device is None: | |||||
| device = model_device | |||||
| triples = self.dataset.triples.to(device) # (|T|, 3) | |||||
| heads, rels, tails = triples.t() # (|T|,) | |||||
| # Build index structures -------------------------------------------------- | |||||
| tails_by_hr: dict[tuple[int, int], list[int]] = defaultdict(list) | |||||
| heads_by_rt: dict[tuple[int, int], list[int]] = defaultdict(list) | |||||
| rels_by_ht: dict[tuple[int, int], list[int]] = defaultdict(list) | |||||
| for h, r, t in zip(heads.tolist(), rels.tolist(), tails.tolist()): | |||||
| tails_by_hr[(h, r)].append(t) | |||||
| heads_by_rt[(r, t)].append(h) | |||||
| rels_by_ht[(h, t)].append(r) | |||||
| # Structural entropies & query lists --------------------------------------- | |||||
| H_tail: dict[int, _Accum] = defaultdict(_Accum) # per relation r | |||||
| H_head: dict[int, _Accum] = defaultdict(_Accum) | |||||
| tail_queries: list[tuple[int, int]] = [] # (h,r) | |||||
| head_queries: list[tuple[int, int]] = [] # (r,t) | |||||
| rel_queries: list[tuple[int, int]] = [] # (h,t) | |||||
| # --- tails -------------------------------------------------------------- | |||||
| for (h, r), ts in tails_by_hr.items(): | |||||
| k = len(ts) | |||||
| H_tail[r].update(math.log(k), 1.0) | |||||
| tail_queries.append((h, r)) | |||||
| # --- heads -------------------------------------------------------------- | |||||
| for (r, t), hs in heads_by_rt.items(): | |||||
| k = len(hs) | |||||
| H_head[r].update(math.log(k), 1.0) | |||||
| head_queries.append((r, t)) | |||||
| # --- relations ---------------------------------------------------------- | |||||
| H_rel_accum = _Accum() | |||||
| for (h, t), rs in rels_by_ht.items(): | |||||
| k = len(rs) | |||||
| H_rel_accum.update(math.log(k), 1.0) | |||||
| rel_queries.append((h, t)) | |||||
| H_struct_tail = {r: acc.mean for r, acc in H_tail.items()} | |||||
| H_struct_head = {r: acc.mean for r, acc in H_head.items()} | |||||
| # H_struct_rel = H_rel_accum.mean # scalar | |||||
| # --------------------------------------------------------------------- # | |||||
| # KL divergences # | |||||
| # --------------------------------------------------------------------- # | |||||
| def _avg_kl_per_relation( | |||||
| queries: list[tuple[int, int]], | |||||
| true_idx_map: dict[tuple[int, int], list[int]], | |||||
| distribution_fn: Callable[[Tensor], Tensor], | |||||
| num_relations: bool = False, # switch to know output size | |||||
| ) -> dict[int, float]: | |||||
| """ | |||||
| Compute *average* KL(q || p) **per relation**. | |||||
| Works for the tail & head conditionals. | |||||
| """ | |||||
| kl_sum: dict[int, float] = defaultdict(float) | |||||
| count: dict[int, int] = defaultdict(int) | |||||
| query_tensor = torch.tensor(queries, dtype=torch.long, device=device) | |||||
| for log_p in _log_prob_distribution(distribution_fn, query_tensor, batch_size): | |||||
| # slice_size handled inside distribution_fn | |||||
| # log_p : (B, |Candidates|) | |||||
| b = log_p.shape[0] | |||||
| # Map back to current slice of queries | |||||
| slice_queries = queries[:b] | |||||
| queries[:] = queries[b:] # consume | |||||
| true_sets = [true_idx_map[q] for q in slice_queries] | |||||
| kl_row = _kl_uniform(log_p, true_sets) # (B,) | |||||
| for (x, r), kl_val in zip(slice_queries, kl_row.tolist()): | |||||
| rel_id = r if not num_relations else x | |||||
| kl_sum[rel_id] += kl_val | |||||
| count[rel_id] += 1 | |||||
| return {r: kl_sum[r] / count[r] for r in kl_sum} | |||||
| # --- tails -------------------------------------------------------------- | |||||
| tail_queries_copy = tail_queries.copy() | |||||
| D_tail = _avg_kl_per_relation( | |||||
| tail_queries_copy, | |||||
| tails_by_hr, | |||||
| lambda q, s=slice_size: self.model.tail_distribution(q, log=True, slice_size=s), | |||||
| ) | |||||
| S_tail = {r: math.exp(-d) for r, d in D_tail.items()} | |||||
| # --- heads -------------------------------------------------------------- | |||||
| head_queries_copy = head_queries.copy() | |||||
| D_head = _avg_kl_per_relation( | |||||
| head_queries_copy, | |||||
| heads_by_rt, | |||||
| lambda q, s=slice_size: self.model.head_distribution(q, log=True, slice_size=s), | |||||
| num_relations=True, | |||||
| ) | |||||
| S_head = {r: math.exp(-d) for r, d in D_head.items()} | |||||
| # --- relations (single global number) ---------------------------------- | |||||
| D_rel_sum = 0.0 | |||||
| for log_p in _log_prob_distribution( | |||||
| lambda q, s=slice_size: self.model.relation_distribution(q, log=True, slice_size=s), | |||||
| torch.tensor(rel_queries, dtype=torch.long, device=device), | |||||
| batch_size, | |||||
| ): | |||||
| slice_size_batch = log_p.shape[0] | |||||
| slice_queries = rel_queries[:slice_size_batch] | |||||
| rel_queries[:] = rel_queries[slice_size_batch:] | |||||
| true_sets = [rels_by_ht[q] for q in slice_queries] | |||||
| D_rel_sum += _kl_uniform(log_p, true_sets).sum().item() | |||||
| D_rel = D_rel_sum / H_rel_accum.tot_w | |||||
| S_rel = math.exp(-D_rel) | |||||
| # --------------------------------------------------------------------- # | |||||
| # Weighted aggregation → C-SWKLF # | |||||
| # --------------------------------------------------------------------- # | |||||
| def _weighted_mean(score: dict[int, float], weight: dict[int, float]) -> float: | |||||
| num = sum(weight[r] * score[r] for r in score) | |||||
| den = sum(weight[r] for r in score) | |||||
| return num / den if den else float("nan") | |||||
| sw_tail = _weighted_mean(S_tail, H_struct_tail) | |||||
| sw_head = _weighted_mean(S_head, H_struct_head) | |||||
| # Relations: numerator & denominator cancel | |||||
| sw_rel = S_rel | |||||
| return alpha * sw_tail + beta * sw_head + gamma * sw_rel |
| # from __future__ import annotations | |||||
| # import math | |||||
| # import os | |||||
| # import random | |||||
| # import threading | |||||
| # from collections import defaultdict | |||||
| # from concurrent.futures import ThreadPoolExecutor, as_completed | |||||
| # from typing import Dict, List, Tuple | |||||
| # import torch | |||||
| # from data.kg_dataset import KGDataset | |||||
| # from metrics.wlcrec import WLCREC # Assuming WLCREC is defined in | |||||
| # from tools import get_pretty_logger | |||||
| # logger = get_pretty_logger(__name__) | |||||
| # # -------------------------------------------------------------------- | |||||
| # # Edge-editing primitives | |||||
| # # -------------------------------------------------------------------- | |||||
| # # -- additions -------------------- | |||||
| # def add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng): | |||||
| # for _ in range(1000): | |||||
| # h = rng.randrange(n_ent) | |||||
| # t = rng.randrange(n_ent) | |||||
| # if colours[h] != colours[t]: | |||||
| # r = rng.randrange(n_rel) | |||||
| # triples.append((h, r, t)) | |||||
| # out_adj[h].append((r, t)) | |||||
| # in_adj[t].append((r, h)) | |||||
| # return | |||||
| # def add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng): | |||||
| # σ = rng.choice(colours) | |||||
| # τ = rng.choice(colours) | |||||
| # h = colours.index(σ) | |||||
| # t = colours.index(τ) | |||||
| # triples.append((h, 0, t)) | |||||
| # out_adj[h].append((0, t)) | |||||
| # in_adj[t].append((0, h)) | |||||
| # # -- removals -------------------- | |||||
| # def rem_div_edge(triples, out_adj, in_adj, colours, depth, rng): | |||||
| # cand = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]] | |||||
| # if cand: | |||||
| # idx = rng.choice(cand) | |||||
| # h, r, t = triples.pop(idx) | |||||
| # out_adj[h].remove((r, t)) | |||||
| # in_adj[t].remove((r, h)) | |||||
| # def rem_det_edge(triples, out_adj, in_adj, colours, depth, rng): | |||||
| # sig_rel = defaultdict(set) | |||||
| # for h, r, t in triples: | |||||
| # σ, τ = colours[h], colours[t] | |||||
| # sig_rel[(σ, τ)].add(r) | |||||
| # cand = [] | |||||
| # for i, (h, r, t) in enumerate(triples): | |||||
| # if len(sig_rel[(colours[h], colours[t])]) == 1: | |||||
| # cand.append(i) | |||||
| # if cand: | |||||
| # idx = rng.choice(cand) | |||||
| # h, r, t = triples.pop(idx) | |||||
| # out_adj[h].remove((r, t)) | |||||
| # in_adj[t].remove((r, h)) | |||||
| # def _search_worker( | |||||
| # seed: int, | |||||
| # triples_init: List[Tuple[int, int, int]], | |||||
| # triples_factory, | |||||
| # n_ent: int, | |||||
| # n_rel: int, | |||||
| # depth: int, | |||||
| # lo: float, | |||||
| # hi: float, | |||||
| # max_iters: int, | |||||
| # ) -> List[Tuple[int, int, int]] | None: | |||||
| # """ | |||||
| # Run the exact same hill‑climb that `tune_crec()` did, | |||||
| # but entirely in this process. Return the edited triples | |||||
| # once c falls in [lo, hi]; return None if we used up all iterations. | |||||
| # """ | |||||
| # rng = random.Random(seed) | |||||
| # triples = triples_init.copy().tolist() | |||||
| # for it in range(max_iters): | |||||
| # # WL‑CREC is *only* recomputed every 1000 edits, exactly like before | |||||
| # if it % 1000 == 0: | |||||
| # dataset = KGDataset( | |||||
| # triples_factory, | |||||
| # triples=torch.tensor(triples, dtype=torch.long), | |||||
| # num_entities=n_ent, | |||||
| # num_relations=n_rel, | |||||
| # ) | |||||
| # wl = WLCREC(dataset) | |||||
| # colours = wl.wl_colours(depth) | |||||
| # *_, c, _ = wl.compute(H=5) # unchanged API | |||||
| # if lo <= c <= hi: # success | |||||
| # logger.info("[seed %d] hit %.4f in %d edits", seed, c, it) | |||||
| # return triples | |||||
| # # ---------- identical edit logic ---------- | |||||
| # if c < lo: | |||||
| # if rng.random() < 0.5: | |||||
| # rem_det_edge(triples, colours, depth, rng) | |||||
| # else: | |||||
| # add_div_edge(triples, colours, depth, n_ent, n_rel, rng) | |||||
| # else: # c > hi | |||||
| # if rng.random() < 0.5: | |||||
| # rem_div_edge(triples, colours, depth, rng) | |||||
| # else: | |||||
| # add_det_edge(triples, colours, depth, n_rel, rng) | |||||
| # # ------------------------------------------ | |||||
| # return None # used up our budget | |||||
| # # -------------------------------------------------------------------- | |||||
| # # Unified tuner | |||||
| # # -------------------------------------------------------------------- | |||||
| # def tune_crec( | |||||
| # wl_crec: WLCREC, | |||||
| # n_ent: int, | |||||
| # n_rel: int, | |||||
| # depth: int, | |||||
| # lo: float, | |||||
| # hi: float, | |||||
| # max_iters: int = 80_000, | |||||
| # seed: int = 42, | |||||
| # ): | |||||
| # triples = wl_crec.dataset.triples.tolist() | |||||
| # rng = random.Random(seed) | |||||
| # for it in range(max_iters): | |||||
| # print(f"\r[iter {it + 1:5d}] ", end="") | |||||
| # if it % 1000 == 0: | |||||
| # dataset = KGDataset( | |||||
| # wl_crec.dataset.triples_factory, | |||||
| # triples=torch.tensor(triples, dtype=torch.long), | |||||
| # num_entities=n_ent, | |||||
| # num_relations=n_rel, | |||||
| # ) | |||||
| # tmp_wl_crec = WLCREC(dataset) | |||||
| # # colours = wl_colours(triples, n_ent, depth) | |||||
| # colours = tmp_wl_crec.wl_colours(depth) | |||||
| # _, _, _, _, c, _ = tmp_wl_crec.compute(H=5) | |||||
| # if lo <= c <= hi: | |||||
| # logging.info("WL-CREC %.4f reached after %d edits (|T|=%d)", c, it, len(triples)) | |||||
| # return triples | |||||
| # if c < lo: | |||||
| # # need ↑ WL-CREC → prefer deletion of deterministic, else add diversifying | |||||
| # if rng.random() < 0.5: | |||||
| # rem_det_edge(triples, colours, depth, rng) | |||||
| # else: | |||||
| # add_div_edge(triples, colours, depth, n_ent, n_rel, rng) | |||||
| # # need ↓ WL-CREC → prefer deletion of diversifying, else add deterministic | |||||
| # elif rng.random() < 0.5: | |||||
| # rem_div_edge(triples, colours, depth, rng) | |||||
| # else: | |||||
| # add_det_edge(triples, colours, depth, n_rel, rng) | |||||
| # if (it + 1) % 10000 == 1: | |||||
| # logging.info("[iter %d] WL-CREC %.4f |T|=%d", it + 1, c, len(triples)) | |||||
| # raise RuntimeError("Exceeded max iterations without hitting target band.") | |||||
| # def _edit_batch( | |||||
| # worker_id: int, | |||||
| # n_edits: int, | |||||
| # colours: List[int], | |||||
| # crec: float, | |||||
| # target_lo: float, | |||||
| # target_hi: float, | |||||
| # depth: int, | |||||
| # n_ent: int, | |||||
| # n_rel: int, | |||||
| # seed_base: int, | |||||
| # ): | |||||
| # """ | |||||
| # Perform `n_edits` topology modifications *locally* and return | |||||
| # (added_triples, removed_triples) lists. | |||||
| # """ | |||||
| # rng = random.Random(seed_base + worker_id) | |||||
| # local_added, local_removed = [], [] | |||||
| # # local graph views: only sizes matter for edit selection | |||||
| # for _ in range(n_edits): | |||||
| # if crec < target_lo: # need ↑ WL-CREC | |||||
| # if rng.random() < 0.5: # • remove deterministic | |||||
| # # We cannot remove by index safely without the global list; | |||||
| # # choose a *signature* and remember intention to delete: | |||||
| # local_removed.append(("det", rng.random())) | |||||
| # else: # • add diversifying | |||||
| # h = rng.randrange(n_ent) | |||||
| # t = rng.randrange(n_ent) | |||||
| # while colours[h] == colours[t]: | |||||
| # h = rng.randrange(n_ent) | |||||
| # t = rng.randrange(n_ent) | |||||
| # r = rng.randrange(n_rel) | |||||
| # local_added.append((h, r, t)) | |||||
| # else: # need ↓ WL-CREC | |||||
| # if rng.random() < 0.5: # • remove diversifying | |||||
| # local_removed.append(("div", rng.random())) | |||||
| # else: # • add deterministic | |||||
| # σ = rng.choice(colours) | |||||
| # τ = rng.choice(colours) | |||||
| # h = colours.index(σ) | |||||
| # t = colours.index(τ) | |||||
| # local_added.append((h, 0, t)) | |||||
| # return local_added, local_removed | |||||
| # def wl_colours(triples, n_ent, depth): | |||||
| # # build adjacency once | |||||
| # out_adj, in_adj = defaultdict(list), defaultdict(list) | |||||
| # for h, r, t in triples: | |||||
| # out_adj[h].append((r, t)) | |||||
| # in_adj[t].append((r, h)) | |||||
| # colours_rounds = [[0] * n_ent] # round-0 colours | |||||
| # for h in range(1, depth + 1): | |||||
| # prev = colours_rounds[-1] | |||||
| # # 1) build textual signatures in parallel | |||||
| # def sig(v): | |||||
| # neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [ | |||||
| # ("↑", r, prev[u]) for r, u in in_adj.get(v, []) | |||||
| # ] | |||||
| # neigh.sort() | |||||
| # return (prev[v], tuple(neigh)) | |||||
| # with ThreadPoolExecutor() as tpe: # cheap threads inside worker | |||||
| # sigs = list(tpe.map(sig, range(n_ent))) | |||||
| # # 2) assign deterministic colour IDs | |||||
| # sig2id: Dict[Tuple, int] = {} | |||||
| # next_round = [0] * n_ent | |||||
| # fresh = 0 | |||||
| # for v, sg in enumerate(sigs): | |||||
| # cid = sig2id.setdefault(sg, fresh) | |||||
| # if cid == fresh: | |||||
| # fresh += 1 | |||||
| # next_round[v] = cid | |||||
| # colours_rounds.append(next_round) | |||||
| # depth_colours = colours_rounds[-1] | |||||
| # return depth_colours | |||||
| # def _metric_worker(args): | |||||
| # triples, n_ent, n_rel, depth = args | |||||
| # dataset = KGDataset( | |||||
| # triples_factory=None, # Not used in this context | |||||
| # triples=torch.tensor(triples, dtype=torch.long), | |||||
| # num_entities=n_ent, | |||||
| # num_relations=n_rel, | |||||
| # ) | |||||
| # wl_crec = WLCREC(dataset) | |||||
| # _, _, _, _, c, _ = wl_crec.compute(H=5, return_full=False) | |||||
| # colours = wl_colours(triples, n_ent, depth) | |||||
| # return c, colours | |||||
| # def tune_crec_parallel_edits( | |||||
| # triples_init, | |||||
| # n_ent: int, | |||||
| # n_rel: int, | |||||
| # depth: int, | |||||
| # target_lo: float, | |||||
| # target_hi: float, | |||||
| # max_iters: int = 80_000, | |||||
| # metric_every: int = 100, | |||||
| # n_workers: int = max(20, math.ceil(os.cpu_count() / 2)), | |||||
| # seed: int = 42, | |||||
| # ): | |||||
| # # -------- shared mutable state (main thread owns it) -------- | |||||
| # triples = triples_init.tolist() | |||||
| # out_adj, in_adj = defaultdict(list), defaultdict(list) | |||||
| # for h, r, t in triples: | |||||
| # out_adj[h].append((r, t)) | |||||
| # in_adj[t].append((r, h)) | |||||
| # pool = ThreadPoolExecutor(max_workers=n_workers) | |||||
| # metric_lock = threading.Lock() # exactly one metric at a time | |||||
| # rng_global = random.Random(seed) | |||||
| # # ----- first metric checkpoint ----- | |||||
| # crec, colours = _metric_worker((triples, n_ent, n_rel, depth)) | |||||
| # edit_budget_total = 0 | |||||
| # for it in range(0, max_iters, metric_every): | |||||
| # # ========================================================= | |||||
| # # 1. PARALLEL EDIT STAGE (metric_every edits in total) | |||||
| # # ========================================================= | |||||
| # futures = [] | |||||
| # edits_per_worker = metric_every // n_workers | |||||
| # extra = metric_every % n_workers | |||||
| # for wid in range(n_workers): | |||||
| # n_edits = edits_per_worker + (1 if wid < extra else 0) | |||||
| # futures.append( | |||||
| # pool.submit( | |||||
| # _edit_batch, | |||||
| # wid, | |||||
| # n_edits, | |||||
| # colours, | |||||
| # crec, | |||||
| # target_lo, | |||||
| # target_hi, | |||||
| # depth, | |||||
| # n_ent, | |||||
| # n_rel, | |||||
| # seed, | |||||
| # ) | |||||
| # ) | |||||
| # # merge when workers finish | |||||
| # for fut in as_completed(futures): | |||||
| # added, removed_specs = fut.result() | |||||
| # # --- apply additions immediately (cheap, conflict-free) | |||||
| # for h, r, t in added: | |||||
| # triples.append((h, r, t)) | |||||
| # out_adj[h].append((r, t)) | |||||
| # in_adj[t].append((r, h)) | |||||
| # # --- apply removals: interpret the spec on *current* graph | |||||
| # for typ, randv in removed_specs: | |||||
| # if typ == "div": | |||||
| # idxs = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]] | |||||
| # else: # 'det' | |||||
| # sig_rel = defaultdict(set) | |||||
| # for h, r, t in triples: | |||||
| # sig_rel[(colours[h], colours[t])].add(r) | |||||
| # idxs = [ | |||||
| # i | |||||
| # for i, (h, r, t) in enumerate(triples) | |||||
| # if len(sig_rel[(colours[h], colours[t])]) == 1 | |||||
| # ] | |||||
| # if idxs: | |||||
| # victim = idxs[int(randv * len(idxs))] | |||||
| # h, r, t = triples.pop(victim) | |||||
| # out_adj[h].remove((r, t)) | |||||
| # in_adj[t].remove((r, h)) | |||||
| # edit_budget_total += metric_every | |||||
| # # ========================================================= | |||||
| # # 2. SINGLE-THREADED METRIC CHECKPOINT (synchronised) | |||||
| # # ========================================================= | |||||
| # # if edit_budget_total % (10 * metric_every) == 0: | |||||
| # if True: | |||||
| # with metric_lock: | |||||
| # crec, colours = _metric_worker((triples, n_ent, n_rel, depth)) | |||||
| # logging.info( | |||||
| # "After %d edits WL-CREC = %.4f |T|=%d", edit_budget_total, crec, len(triples) | |||||
| # ) | |||||
| # if target_lo <= crec <= target_hi: | |||||
| # logging.info("Target band reached.") | |||||
| # pool.shutdown(wait=True) | |||||
| # return triples | |||||
| # pool.shutdown(wait=True) | |||||
| # raise RuntimeError("Exceeded max iteration budget without success.") | |||||
| # from __future__ import annotations | |||||
| # import logging | |||||
| # import random | |||||
| # from collections import defaultdict | |||||
| # from concurrent.futures import ThreadPoolExecutor, as_completed | |||||
| # from typing import List, Tuple | |||||
| # import torch | |||||
| # from data.kg_dataset import KGDataset | |||||
| # from metrics.wlcrec import WLCREC | |||||
| # from tools import get_pretty_logger | |||||
| # logger = get_pretty_logger(__name__) | |||||
| # # ------------ additions ------------ | |||||
| # def propose_add_div(triples: List, colours, n_ent: int, n_rel: int, rng: random.Random): | |||||
| # for _ in range(1000): | |||||
| # h, t = rng.randrange(n_ent), rng.randrange(n_ent) | |||||
| # if colours[h] != colours[t]: | |||||
| # r = rng.randrange(n_rel) | |||||
| # return ("add", (h, r, t)) | |||||
| # return None # fell through – extremely rare | |||||
| # def propose_add_det(colours, n_rel: int, rng: random.Random): | |||||
| # σ, τ = rng.choice(colours), rng.choice(colours) | |||||
| # h, t = colours.index(σ), colours.index(τ) | |||||
| # return ("add", (h, 0, t)) # rel 0 = deterministic | |||||
| # # ------------ removals ------------ | |||||
| # def propose_rem_div(triples: List, colours, rng: random.Random): | |||||
| # cand = [trp for trp in triples if colours[trp[0]] != colours[trp[2]]] | |||||
| # return ("rem", rng.choice(cand)) if cand else None | |||||
| # def propose_rem_det(triples: List, colours, rng: random.Random): | |||||
| # sig_rel = defaultdict(set) | |||||
| # for h, r, t in triples: | |||||
| # sig_rel[(colours[h], colours[t])].add(r) | |||||
| # cand = [trp for trp in triples if len(sig_rel[(colours[trp[0]], colours[trp[2]])]) == 1] | |||||
| # return ("rem", rng.choice(cand)) if cand else None | |||||
| # def make_edit_proposal( | |||||
| # triples_snapshot: List, | |||||
| # colours_snapshot, | |||||
| # c: float, | |||||
| # lo: float, | |||||
| # hi: float, | |||||
| # n_ent: int, | |||||
| # n_rel: int, | |||||
| # depth: int, # still here for future use / signatures | |||||
| # seed: int, | |||||
| # ): | |||||
| # """Return exactly one ('add' | 'rem', triple) proposal or None.""" | |||||
| # rng = random.Random(seed) | |||||
| # # -- decide which kind of edit we want, *given* the current c ------------ | |||||
| # if c < lo: # ↑ WL‑CREC (delete deterministic ∨ add diversifying) | |||||
| # chooser = (propose_rem_det, propose_add_div) | |||||
| # else: # ↓ WL‑CREC (delete diversifying ∨ add deterministic) | |||||
| # chooser = (propose_rem_div, propose_add_det) | |||||
| # op = rng.choice(chooser) | |||||
| # return ( | |||||
| # op(triples_snapshot, colours_snapshot, n_ent, n_rel, rng) | |||||
| # if op.__name__.startswith("propose_add") | |||||
| # else op(triples_snapshot, colours_snapshot, rng) | |||||
| # ) | |||||
| # def tune_crec_parallel_edits( | |||||
| # triples: List, | |||||
| # n_ent: int, | |||||
| # n_rel: int, | |||||
| # depth: int, | |||||
| # lo: float, | |||||
| # hi: float, | |||||
| # *, | |||||
| # max_iters: int = 80_000, | |||||
| # edits_per_eval: int = 1000, # == old “if it % 1000 == 0” | |||||
| # batch_size: int = 256, # how many proposals we farm out at once | |||||
| # max_workers: int = 4, | |||||
| # seed: int = 42, | |||||
| # ) -> List: | |||||
| # rng_global = random.Random(seed) | |||||
| # triples = set(triples) # deduplicate if needed | |||||
| # with ThreadPoolExecutor(max_workers=max_workers) as pool: | |||||
| # proposal_seed = seed * 997 # deterministic but different stream | |||||
| # # edit_counter = 0 | |||||
| # # while edit_counter < max_iters: | |||||
| # # # ----------------- expensive part (single‑thread) ---------------- | |||||
| # # dataset = KGDataset( | |||||
| # # triples_factory=None, # Not used in this context | |||||
| # # triples=torch.tensor(triples, dtype=torch.long), | |||||
| # # num_entities=n_ent, | |||||
| # # num_relations=n_rel, | |||||
| # # ) | |||||
| # # tmp = WLCREC(dataset) | |||||
| # # colours = tmp.wl_colours(depth) | |||||
| # # *_, c, _ = tmp.compute(H=5) | |||||
| # # # ----------------------------------------------------------------- | |||||
| # # if lo <= c <= hi: | |||||
| # # logger.info( | |||||
| # # "WL‑CREC %.4f reached after %d edits |T|=%d", c, edit_counter, len(triples) | |||||
| # # ) | |||||
| # # return triples | |||||
| # # ============ parallel block: just make `edits_per_eval` proposals | |||||
| # needed = min(edits_per_eval, max_iters - edit_counter) | |||||
| # proposals = [] | |||||
| # while len(proposals) < needed: | |||||
| # # launch a batch of workers | |||||
| # futs = [ | |||||
| # pool.submit( | |||||
| # make_edit_proposal, | |||||
| # triples, | |||||
| # colours, | |||||
| # c, | |||||
| # lo, | |||||
| # hi, | |||||
| # n_ent, | |||||
| # n_rel, | |||||
| # depth, | |||||
| # proposal_seed + i, | |||||
| # ) | |||||
| # for i in range(batch_size) | |||||
| # ] | |||||
| # for f in as_completed(futs): | |||||
| # prop = f.result() | |||||
| # if prop is not None: | |||||
| # proposals.append(prop) | |||||
| # if len(proposals) == needed: | |||||
| # break | |||||
| # proposal_seed += batch_size # move RNG window forward | |||||
| # # -------------- apply the gathered proposals *sequentially* ------- | |||||
| # for kind, trp in proposals: | |||||
| # if kind == "add": | |||||
| # triples.append(trp) | |||||
| # else: # "rem" | |||||
| # try: | |||||
| # triples.remove(trp) | |||||
| # except ValueError: | |||||
| # pass # already gone – benign collision | |||||
| # # ----------------------------------------------------------------- | |||||
| # edit_counter += needed | |||||
| # if edit_counter % 1_000 == 0: | |||||
| # logger.info("[iter %d] c=%.4f |T|=%d", edit_counter, c, len(triples)) | |||||
| # raise RuntimeError("Exceeded max_iters without hitting target band.") | |||||
| import os | |||||
| import random | |||||
| import threading | |||||
| from collections import defaultdict | |||||
| from typing import Dict, List, Tuple | |||||
| # -------------------------------------------------------------------- | |||||
| # Logging helper (replace with your own if you prefer) -------------- | |||||
| # -------------------------------------------------------------------- | |||||
| from tools import get_pretty_logger | |||||
| logging = get_pretty_logger(__name__) | |||||
| # -------------------------------------------------------------------- | |||||
| # Original edge‑editing primitives (unchanged) ---------------------- | |||||
| # -------------------------------------------------------------------- | |||||
| def add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng): | |||||
| for _ in range(1000): | |||||
| h = rng.randrange(n_ent) | |||||
| t = rng.randrange(n_ent) | |||||
| if colours[h] != colours[t]: | |||||
| r = rng.randrange(n_rel) | |||||
| triples.append((h, r, t)) | |||||
| out_adj[h].append((r, t)) | |||||
| in_adj[t].append((r, h)) | |||||
| return | |||||
| def add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng): | |||||
| σ = rng.choice(colours) | |||||
| τ = rng.choice(colours) | |||||
| h = colours.index(σ) | |||||
| t = colours.index(τ) | |||||
| triples.append((h, 0, t)) | |||||
| out_adj[h].append((0, t)) | |||||
| in_adj[t].append((0, h)) | |||||
| def rem_div_edge(triples, out_adj, in_adj, colours, depth, rng): | |||||
| cand = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]] | |||||
| if cand: | |||||
| idx = rng.choice(cand) | |||||
| h, r, t = triples.pop(idx) | |||||
| out_adj[h].remove((r, t)) | |||||
| in_adj[t].remove((r, h)) | |||||
| def rem_det_edge(triples, out_adj, in_adj, colours, depth, rng): | |||||
| sig_rel = defaultdict(set) | |||||
| for h, r, t in triples: | |||||
| σ, τ = colours[h], colours[t] | |||||
| sig_rel[(σ, τ)].add(r) | |||||
| cand = [] | |||||
| for i, (h, r, t) in enumerate(triples): | |||||
| if len(sig_rel[(colours[h], colours[t])]) == 1: | |||||
| cand.append(i) | |||||
| if cand: | |||||
| idx = rng.choice(cand) | |||||
| h, r, t = triples.pop(idx) | |||||
| out_adj[h].remove((r, t)) | |||||
| in_adj[t].remove((r, h)) | |||||
| def _worker( | |||||
| *, | |||||
| worker_id: int, | |||||
| rng: random.Random, | |||||
| triples: List[Tuple[int, int, int]], | |||||
| out_adj: Dict[int, List[Tuple[int, int]]], | |||||
| in_adj: Dict[int, List[Tuple[int, int]]], | |||||
| colours: List[int], | |||||
| n_ent: int, | |||||
| n_rel: int, | |||||
| depth: int, | |||||
| c: float, | |||||
| lo: float, | |||||
| hi: float, | |||||
| max_iters: int, | |||||
| state_lock: threading.Lock, | |||||
| stop_event: threading.Event, | |||||
| ): | |||||
| """One thread: mutate the *shared* structures until success/stop.""" | |||||
| for it in range(max_iters): | |||||
| if stop_event.is_set(): | |||||
| return # someone else finished ─ exit early | |||||
| with state_lock: # protect the shared graph | |||||
| # if lo <= c <= hi: | |||||
| # logging.info( | |||||
| # "[worker %d] converged after %d steps (CREC %.4f, |T|=%d)", | |||||
| # worker_id, | |||||
| # it, | |||||
| # c, | |||||
| # len(triples), | |||||
| # ) | |||||
| # stop_event.set() | |||||
| # return | |||||
| # Choose and apply one edit ----------------------------- | |||||
| if c < lo: # need ↑ CREC | |||||
| if rng.random() < 0.5: | |||||
| rem_det_edge(triples, out_adj, in_adj, colours, depth, rng) | |||||
| else: | |||||
| add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng) | |||||
| elif rng.random() < 0.5: | |||||
| rem_div_edge(triples, out_adj, in_adj, colours, depth, rng) | |||||
| else: | |||||
| add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng) | |||||
| logging.warning("[worker %d] reached max_iters", worker_id) | |||||
| stop_event.set() | |||||
| return | |||||
| # -------------------------------------------------------------------- | |||||
| # Public API -------------------------------------------------------- | |||||
| # -------------------------------------------------------------------- | |||||
| def fast_tune_crec( | |||||
| triples: List, | |||||
| colours: List, | |||||
| n_ent: int, | |||||
| n_rel: int, | |||||
| depth: int, | |||||
| c: float, | |||||
| lo: float, | |||||
| hi: float, | |||||
| max_iters: int = 1000, | |||||
| max_workers: int | None = None, | |||||
| seeds: List[int] | None = None, | |||||
| ) -> List[Tuple[int, int, int]]: | |||||
| """Tune WL‑CREC with *shared* triples using multiple threads. | |||||
| Returns the **same list instance** that was passed in – already | |||||
| modified in place by the winning thread. | |||||
| """ | |||||
| if max_workers is None: | |||||
| # max_workers = os.cpu_count() or 4 | |||||
| max_workers = 4 | |||||
| if seeds is None: | |||||
| seeds = [42 + i for i in range(max_workers)] | |||||
| assert len(seeds) >= max_workers, "Need at least one seed per worker" | |||||
| # Prepare adjacency once (shared) -------------------------------- | |||||
| out_adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list) | |||||
| in_adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list) | |||||
| for h, r, t in triples: | |||||
| out_adj[h].append((r, t)) | |||||
| in_adj[t].append((r, h)) | |||||
| state_lock = threading.Lock() | |||||
| stop_event = threading.Event() | |||||
| logging.info( | |||||
| "Launching %d threads on shared triples (target %.3f–%.3f)", | |||||
| max_workers, | |||||
| lo, | |||||
| hi, | |||||
| ) | |||||
| threads: List[threading.Thread] = [] | |||||
| for wid in range(max_workers): | |||||
| t = threading.Thread( | |||||
| name=f"tune‑crec‑worker‑{wid}", | |||||
| target=_worker, | |||||
| kwargs=dict( | |||||
| worker_id=wid, | |||||
| rng=random.Random(seeds[wid]), | |||||
| triples=triples, | |||||
| out_adj=out_adj, | |||||
| in_adj=in_adj, | |||||
| colours=colours, | |||||
| n_ent=n_ent, | |||||
| n_rel=n_rel, | |||||
| depth=depth, | |||||
| c=c, | |||||
| lo=lo, | |||||
| hi=hi, | |||||
| max_iters=max_iters, | |||||
| state_lock=state_lock, | |||||
| stop_event=stop_event, | |||||
| ), | |||||
| daemon=False, | |||||
| ) | |||||
| threads.append(t) | |||||
| t.start() | |||||
| for t in threads: | |||||
| t.join() | |||||
| if not stop_event.is_set(): | |||||
| raise RuntimeError("No thread converged – try increasing max_iters or widen band") | |||||
| return triples |
| from __future__ import annotations | |||||
| import logging | |||||
| import random | |||||
| from collections import defaultdict, deque | |||||
| from math import isclose | |||||
| from typing import List | |||||
| import torch | |||||
| from data.kg_dataset import KGDataset | |||||
| from metrics.wlcrec import WLCREC | |||||
| from tools import get_pretty_logger | |||||
| logging = get_pretty_logger(__name__) | |||||
| # -------------------------------------------------------------------- | |||||
| # Radius-k patch sampler | |||||
| # -------------------------------------------------------------------- | |||||
| def radius_patch_sample( | |||||
| all_triples: List, | |||||
| ratio: float, | |||||
| radius: int, | |||||
| rng: random.Random, | |||||
| ) -> List: | |||||
| """ | |||||
| Take ~ratio fraction of triples by picking random seeds and | |||||
| keeping *all* triples within ≤ radius hops (undirected) of them. | |||||
| """ | |||||
| n_total = len(all_triples) | |||||
| target = int(ratio * n_total) | |||||
| # Build entity → triple-indices adjacency for BFS | |||||
| ent2trip = defaultdict(list) | |||||
| for idx, (h, _, t) in enumerate(all_triples): | |||||
| ent2trip[h].append(idx) | |||||
| ent2trip[t].append(idx) | |||||
| chosen_idx: set[int] = set() | |||||
| visited_ent = set() | |||||
| while len(chosen_idx) < target: | |||||
| seed_idx = rng.randrange(n_total) | |||||
| if seed_idx in chosen_idx: | |||||
| continue | |||||
| # BFS over entities | |||||
| queue = deque() | |||||
| for e in all_triples[seed_idx][::2]: # subject & object | |||||
| if e not in visited_ent: | |||||
| queue.append((e, 0)) | |||||
| visited_ent.add(e) | |||||
| while queue: | |||||
| ent, dist = queue.popleft() | |||||
| for tidx in ent2trip[ent]: | |||||
| if tidx not in chosen_idx: | |||||
| chosen_idx.add(tidx) | |||||
| if dist < radius: | |||||
| for tidx in ent2trip[ent]: | |||||
| h, _, t = all_triples[tidx] | |||||
| for nb in (h, t): | |||||
| if nb not in visited_ent: | |||||
| visited_ent.add(nb) | |||||
| queue.append((nb, dist + 1)) | |||||
| # guard against infinite loop on very small ratio | |||||
| if isclose(ratio, 0.0) and chosen_idx: | |||||
| break | |||||
| return [all_triples[i] for i in chosen_idx] | |||||
| # -------------------------------------------------------------------- | |||||
| # Main driver: repeated sampling until WL-CREC band hit | |||||
| # -------------------------------------------------------------------- | |||||
| def find_sample( | |||||
| all_triples: List, | |||||
| n_ent: int, | |||||
| n_rel: int, | |||||
| depth: int, | |||||
| ratio: float, | |||||
| radius: int, | |||||
| lower: float, | |||||
| upper: float, | |||||
| max_tries: int, | |||||
| seed: int, | |||||
| ) -> List: | |||||
| for trial in range(1, max_tries + 1): | |||||
| rng = random.Random(seed + trial) # fresh randomness each try | |||||
| sample = radius_patch_sample(all_triples, ratio, radius, rng) | |||||
| sample_ds = KGDataset( | |||||
| triples_factory=None, | |||||
| triples=torch.tensor(sample, dtype=torch.long), | |||||
| num_entities=n_ent, | |||||
| num_relations=n_rel, | |||||
| split="train", | |||||
| ) | |||||
| wl_crec = WLCREC(sample_ds) | |||||
| crec_val = wl_crec.compute(depth)[-2] | |||||
| logging.info("Try %d |T'|=%d WL-CREC=%.4f", trial, len(sample), crec_val) | |||||
| if lower <= crec_val <= upper: | |||||
| logging.info("Success after %d tries.", trial) | |||||
| return sample | |||||
| raise RuntimeError(f"No sample reached WL-CREC band in {max_tries} tries.") |
| from __future__ import annotations | |||||
| import logging | |||||
| import random | |||||
| from math import log | |||||
| from typing import Dict, List, Tuple | |||||
| import torch | |||||
| from torch import Tensor | |||||
| from tools import get_pretty_logger | |||||
| logging = get_pretty_logger(__name__) | |||||
| Entity = int | |||||
| Relation = int | |||||
| Triple = Tuple[Entity, Relation, Entity] | |||||
| Color = int | |||||
| # --------------------------------------------------------------------- | |||||
| # Weisfeiler–Lehman colouring (GPU friendly) | |||||
| # --------------------------------------------------------------------- | |||||
| def wl_colours_gpu(triples: Tensor, n_ent: int, depth: int, device: torch.device) -> List[Tensor]: | |||||
| """ | |||||
| GPU implementation of classic Weisfeiler‑Lehman refinement. | |||||
| 1. Each node starts with colour 0. | |||||
| 2. At iteration h we hash the multiset of | |||||
| ⬇︎ (r, colour(t)) and ⬆︎ (r, colour(h)) | |||||
| signatures into new colour ids. | |||||
| 3. Hashing is done with a fast 64‑bit mix; collisions are unlikely and | |||||
| immaterial for CREC (only *relative* colours matter). | |||||
| """ | |||||
| h, r, t = triples.unbind(dim=1) # (|T|,) | |||||
| colours = [torch.zeros(n_ent, dtype=torch.long, device=device)] # C⁰ = 0 | |||||
| # pre‑compute 64‑bit mix constants once | |||||
| mix_r = ( | |||||
| torch.arange(triples[:, 1].max() + 1, device=device, dtype=torch.long) | |||||
| * 1_146_189_683_093_321_123 | |||||
| ) | |||||
| mix_c = 8_636_673_225_737_527_201 | |||||
| for _ in range(depth): | |||||
| prev = colours[-1] # (|V|,) | |||||
| # signatures for outgoing edges | |||||
| sig_out = mix_r[r] ^ (prev[t] * mix_c) | |||||
| # signatures for incoming edges | |||||
| sig_in = mix_r[r] ^ (prev[h] * mix_c) ^ 0x9E3779B97F4A7C15 | |||||
| # bucket signatures back to the source/target nodes | |||||
| col_out = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, h, sig_out) | |||||
| col_in = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, t, sig_in) | |||||
| # combine ↓ and ↑ multiset hashes with current colour | |||||
| raw = (prev * 3_205_813_371) ^ col_out ^ (col_in << 1) | |||||
| # re‑map to dense consecutive ids with torch.unique | |||||
| uniq, new = torch.unique(raw, sorted=True, return_inverse=True) | |||||
| colours.append(new) | |||||
| return colours # length = depth+1, each (|V|,) long | |||||
| # --------------------------------------------------------------------- | |||||
| # WL‑CREC metric | |||||
| # --------------------------------------------------------------------- | |||||
| def wl_crec_gpu( | |||||
| triples: Tensor, colours: List[Tensor], n_ent: int, n_rel: int, depth: int, device: torch.device | |||||
| ) -> float: | |||||
| log_n = log(max(n_ent, 2)) | |||||
| # ------- pattern‑diversity C ------------------------------------- | |||||
| C = 0.0 | |||||
| for c in colours: # (|V|,) | |||||
| # histogram via bincount on GPU | |||||
| hist = torch.bincount(c).float() | |||||
| p = hist / n_ent | |||||
| C += -(p * torch.log(p.clamp_min(1e-30))).sum().item() | |||||
| C /= (depth + 1) * log_n | |||||
| # ------- residual entropy H_c ------------------------------------ | |||||
| h, r, t = triples.unbind(dim=1) | |||||
| sigma, tau = colours[depth][h], colours[depth][t] # (|T|,) | |||||
| # encode (sigma,tau) pairs into a single 64‑bit key | |||||
| sig_keys = (sigma.to(torch.int64) << 32) + tau.to(torch.int64) | |||||
| # total count per signature | |||||
| sig_unique, sig_inv, sig_counts = torch.unique( | |||||
| sig_keys, return_inverse=True, return_counts=True, sorted=False | |||||
| ) | |||||
| # build 2‑D contingency table counts[(sigma,tau), r] | |||||
| m = triples.size(0) | |||||
| keys_2d = sig_inv * n_rel + r | |||||
| _, rel_counts = torch.unique(keys_2d, return_counts=True, sorted=False) | |||||
| # rel_counts is aligned with the *compact* key list; rebuild dense tensor | |||||
| rc_dense = torch.zeros(len(sig_unique), n_rel, device=device, dtype=torch.long) | |||||
| rc_dense.scatter_add_(0, keys_2d.unsqueeze(1), torch.ones(m, device=device, dtype=torch.long)) | |||||
| # conditional entropy per (sigma,tau) | |||||
| p_r = rc_dense.float() / sig_counts.unsqueeze(1).float() | |||||
| inner = -(p_r * torch.log(p_r.clamp_min(1e-30))).sum(dim=1) # (|sigma|,) | |||||
| Hc = (sig_counts.float() / m * inner).sum().item() / log(max(n_rel, 2)) | |||||
| return C * Hc | |||||
| # --------------------------------------------------------------------- | |||||
| # delta‑entropy helpers (still CPU side for clarity; cost is negligible) | |||||
| # --------------------------------------------------------------------- | |||||
| # The deterministic delta‑formulas depend on small dictionaries and are evaluated | |||||
| # on ≤256 candidate edges each iteration – copy them from the original script. | |||||
| inner_entropy = ... # unchanged | |||||
| delta_h_cond_remove = ... # unchanged | |||||
| delta_h_cond_add = ... # unchanged | |||||
| # --------------------------------------------------------------------- | |||||
| # Greedy delta‑search driver | |||||
| # --------------------------------------------------------------------- | |||||
| def greedy_delta_tune_gpu( | |||||
| triples: List[Triple], | |||||
| n_ent: int, | |||||
| n_rel: int, | |||||
| depth: int, | |||||
| lower: float, | |||||
| upper: float, | |||||
| device: torch.device, | |||||
| max_iters: int = 40_000, | |||||
| sample_size: int = 256, | |||||
| seed: int = 0, | |||||
| ) -> List[Triple]: | |||||
| # book‑keeping in Python lists (cheap); heavy math in torch on GPU | |||||
| rng = random.Random(seed) | |||||
| # cache torch version of triples for fast metric eval | |||||
| def triples_to_tensor(ts: List[Triple]) -> Tensor: | |||||
| return torch.tensor(ts, dtype=torch.long, device=device, requires_grad=False) | |||||
| for it in range(max_iters): | |||||
| tri_tensor = triples_to_tensor(triples) | |||||
| colours = wl_colours_gpu(tri_tensor, n_ent, depth, device) | |||||
| crec = wl_crec_gpu(tri_tensor, colours, n_ent, n_rel, depth, device) | |||||
| if lower <= crec <= upper: | |||||
| logging.info("WL‑CREC %.4f reached after %d edits (|T|=%d)", crec, it, len(triples)) | |||||
| return triples | |||||
| # ---------------------------------------------------------------- | |||||
| # build signature statistics (CPU, cheap: O(|T|)) | |||||
| # ---------------------------------------------------------------- | |||||
| depth_col = colours[depth].cpu() | |||||
| sig_cnt: Dict[Tuple[int, int], int] = {} | |||||
| rel_cnt: Dict[Tuple[int, int, int], int] = {} | |||||
| for h_idx, (h, r, t) in enumerate(triples): | |||||
| sigma, tau = int(depth_col[h]), int(depth_col[t]) | |||||
| sig_cnt[(sigma, tau)] = sig_cnt.get((sigma, tau), 0) + 1 | |||||
| rel_cnt[(sigma, tau, r)] = rel_cnt.get((sigma, tau, r), 0) + 1 | |||||
| det_edges, div_edges = [], [] | |||||
| for idx, (h, r, t) in enumerate(triples): | |||||
| sigma, tau = int(depth_col[h]), int(depth_col[t]) | |||||
| if sum(rel_cnt.get((sigma, tau, rr), 0) > 0 for rr in range(n_rel)) == 1: | |||||
| det_edges.append(idx) | |||||
| if sigma != tau: | |||||
| div_edges.append(idx) | |||||
| # ---------------------------------------------------------------- | |||||
| # candidate generation + best edit selection | |||||
| # ---------------------------------------------------------------- | |||||
| total = len(triples) | |||||
| target_high = crec < lower # need to raise WL‑CREC? | |||||
| candidates = [] | |||||
| if target_high and det_edges: | |||||
| rng.shuffle(det_edges) | |||||
| for idx in det_edges[:sample_size]: | |||||
| h, r, t = triples[idx] | |||||
| sig = (int(depth_col[h]), int(depth_col[t])) | |||||
| delta = delta_h_cond_remove(sig, r, sig_cnt, rel_cnt, total, crec, n_rel) | |||||
| if delta > 0: | |||||
| candidates.append(("remove", idx, delta)) | |||||
| elif not target_high and det_edges: | |||||
| rng.shuffle(det_edges) | |||||
| for idx in det_edges[:sample_size]: | |||||
| h, r, t = triples[idx] | |||||
| sig = (int(depth_col[h]), int(depth_col[t])) | |||||
| delta = delta_h_cond_add(sig, r, sig_cnt, rel_cnt, total, crec, n_rel) | |||||
| if delta < 0: | |||||
| candidates.append(("add", (h, r, t), delta)) | |||||
| # fall‑back heuristics | |||||
| if not candidates: | |||||
| if target_high and div_edges: | |||||
| idx = rng.choice(div_edges) | |||||
| candidates.append(("remove", idx, 1e-9)) | |||||
| elif not target_high: | |||||
| idx = rng.choice(det_edges) if det_edges else rng.randrange(total) | |||||
| h, r, t = triples[idx] | |||||
| candidates.append(("add", (h, r, t), -1e-9)) | |||||
| best = ( | |||||
| max(candidates, key=lambda x: x[2]) | |||||
| if target_high | |||||
| else min(candidates, key=lambda x: x[2]) | |||||
| ) | |||||
| # apply edit | |||||
| if best[0] == "remove": | |||||
| triples.pop(best[1]) | |||||
| else: | |||||
| triples.append(best[1]) | |||||
| if (it + 1) % 1_000 == 0: | |||||
| logging.info("[iter %d] WL‑CREC %.4f |T|=%d", it + 1, crec, len(triples)) | |||||
| raise RuntimeError("Max iterations exceeded without hitting WL‑CREC band.") |
| import torch | |||||
| from models.base_model import ERModel | |||||
| def _rank(scores: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |||||
| return (scores > scores[target]).sum() + 1 | |||||
| def simple_ranking(model: ERModel, triples: torch.Tensor) -> torch.Tensor: | |||||
| h, r, t = triples[:, 0], triples[:, 1], triples[:, 2] | |||||
| E = model.num_entities | |||||
| ranks = [] | |||||
| with torch.no_grad(): | |||||
| for hi, ri, ti in zip(h, r, t): | |||||
| tails = torch.arange(E, device=triples.device) | |||||
| triples = torch.stack([hi.repeat(E), ri.repeat(E), tails], dim=1) | |||||
| scores = model(triples) | |||||
| ranks.append(_rank(scores, ti)) | |||||
| return torch.tensor(ranks, device=triples.device, dtype=torch.float32) | |||||
| def mrr(r: torch.Tensor) -> torch.Tensor: | |||||
| return (1 / r).mean() | |||||
| def hits_at_k(r: torch.Tensor, k: int = 10) -> torch.Tensor: | |||||
| return (r <= k).float().mean() |
| from __future__ import annotations | |||||
| import math | |||||
| from collections import Counter, defaultdict | |||||
| from typing import TYPE_CHECKING, Dict, List, Tuple | |||||
| import torch | |||||
| from .base_metric import BaseMetric | |||||
| if TYPE_CHECKING: | |||||
| from data.kg_dataset import KGDataset | |||||
| def _entropy_from_counter(counter: Counter[int], n: int) -> float: | |||||
| """Return Shannon entropy (nats) of a colour distribution of length *n*.""" | |||||
| ent = 0.0 | |||||
| for cnt in counter.values(): | |||||
| if cnt: | |||||
| p = cnt / n | |||||
| ent -= p * math.log(p) | |||||
| return ent | |||||
| def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]: | |||||
| """Return (C_ratio, C_NWLEC) ∈ [0,1]² from per-layer entropies.""" | |||||
| if n <= 1: | |||||
| return 0.0, 0.0 | |||||
| H_plus_1 = len(entropies) | |||||
| log_n = math.log(n) | |||||
| c_ratio = sum(ent / log_n for ent in entropies) / H_plus_1 | |||||
| c_nwlec = sum((math.exp(ent) - 1) / (n - 1) for ent in entropies) / H_plus_1 | |||||
| return c_ratio, c_nwlec | |||||
| def _relation_families( | |||||
| triples: torch.Tensor, num_relations: int, *, thresh: float = 0.9 | |||||
| ) -> torch.Tensor: | |||||
| """Return a LongTensor mapping each relation id → family id. | |||||
| Two relations *r, r_inv* are put into the same family when **at least | |||||
| `thresh` fraction** of the edges labelled *r* have a **single** reverse edge | |||||
| labelled *r_inv*. | |||||
| """ | |||||
| assert 0.0 < thresh <= 1.0 | |||||
| # Build map (h,t) -> list of relations on that edge direction. | |||||
| edge2rels: Dict[Tuple[int, int], List[int]] = defaultdict(list) | |||||
| for h, r, t in triples.tolist(): | |||||
| edge2rels[(h, t)].append(int(r)) | |||||
| # Count reciprocal co‑occurrences (r, r_rev). | |||||
| pair_counts: Dict[Tuple[int, int], int] = defaultdict(int) | |||||
| rel_totals: List[int] = [0] * num_relations | |||||
| for (h, t), rels in edge2rels.items(): | |||||
| rev_rels = edge2rels.get((t, h)) | |||||
| if not rev_rels: | |||||
| continue | |||||
| for r in rels: | |||||
| rel_totals[r] += 1 | |||||
| for r_rev in rev_rels: | |||||
| pair_counts[(r, r_rev)] += 1 | |||||
| # Proposed inverse mapping r -> r_inv. | |||||
| proposed: List[int | None] = [None] * num_relations | |||||
| for (r, r_rev), cnt in pair_counts.items(): | |||||
| if cnt == 0: | |||||
| continue | |||||
| if cnt / rel_totals[r] >= thresh: | |||||
| # majority of r's edges are reversed by r_rev | |||||
| proposed[r] = r_rev | |||||
| # Build family ids (transitive closure is overkill; data are simple). | |||||
| family = list(range(num_relations)) | |||||
| for r, rinv in enumerate(proposed): | |||||
| if rinv is not None: | |||||
| root = min(family[r], family[rinv]) | |||||
| family[r] = family[rinv] = root | |||||
| return torch.tensor(family, dtype=torch.long) | |||||
| def _conditional_relation_entropy( | |||||
| colours: List[int], | |||||
| triples: torch.Tensor, | |||||
| family_map: torch.Tensor, | |||||
| ) -> float: | |||||
| counts: Dict[Tuple[int, int], Dict[int, int]] = defaultdict(lambda: defaultdict(int)) | |||||
| m = int(triples.size(0)) | |||||
| fam = family_map # alias for speed | |||||
| for h, r, t in triples.tolist(): | |||||
| # key = (colours[h], colours[t]) # ordered key (direction‑aware) | |||||
| key = tuple(sorted((colours[h], colours[t]))) # unordered key (direction‑agnostic) | |||||
| counts[key][int(fam[r])] += 1 # ▼ use family id | |||||
| h_cond = 0.0 | |||||
| for rel_counts in counts.values(): | |||||
| total_s = sum(rel_counts.values()) | |||||
| P_s = total_s / m | |||||
| inv_total = 1.0 / total_s | |||||
| for cnt in rel_counts.values(): | |||||
| p_rs = cnt * inv_total | |||||
| h_cond -= P_s * p_rs * math.log(p_rs) | |||||
| return h_cond | |||||
| class WLCREC(BaseMetric): | |||||
| """Class to compute the Weisfeiler-Lehman Entropy Complexity (WLEC) for a knowledge graph.""" | |||||
| def __init__(self, dataset: KGDataset) -> None: | |||||
| super().__init__(dataset) | |||||
| def compute( | |||||
| self, | |||||
| H: int = 3, | |||||
| cond_h: int = 1, | |||||
| inv_thresh: float = 0.9, | |||||
| return_full: bool = False, | |||||
| progress: bool = True, | |||||
| ) -> float | Tuple[float, List[float]]: | |||||
| """Compute WL-entropy scores and the composite difficulty metric. | |||||
| Parameters | |||||
| ---------- | |||||
| dataset | |||||
| Any object exposing ``.triples``, ``.num_entities``, ``.num_relations``. | |||||
| H | |||||
| WL refinement depth (*≥0*). Entropy is recorded for **H+1** layers. | |||||
| return_full | |||||
| Also return the list of per-layer entropies. | |||||
| progress | |||||
| Print textual progress bar. | |||||
| Returns | |||||
| ------- | |||||
| avg_entropy, C_ratio, C_NWLEC, H_cond, D_ratio, D_NWLEC | |||||
| *Six* scalars (floats). If ``return_full=True`` an extra list of | |||||
| layer entropies is appended. | |||||
| """ | |||||
| # compute relation family map once | |||||
| family_map = _relation_families( | |||||
| self.dataset.triples, self.dataset.num_relations, thresh=inv_thresh | |||||
| ) | |||||
| # ─── WL iterations ──────────────────────────────────────────────────── | |||||
| n = self.dataset.num_entities | |||||
| colour: List[int] = [1] * n # colour 1 for everyone | |||||
| entropies: List[float] = [] | |||||
| colours_per_h: List[List[int]] = [] | |||||
| def _print_prog(i: int) -> None: | |||||
| if progress: | |||||
| print(f"\rWL iteration {i}/{H}", end="", flush=True) | |||||
| for h in range(H + 1): | |||||
| _print_prog(h) | |||||
| colours_per_h.append(colour.copy()) | |||||
| # entropy of current colouring | |||||
| freq = Counter(colour) | |||||
| entropies.append(_entropy_from_counter(freq, n)) | |||||
| if h == H: | |||||
| break | |||||
| # refine colours | |||||
| bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {} | |||||
| next_colour: List[int] = [0] * n | |||||
| for v in range(n): | |||||
| T: List[Tuple[int, int, int]] = [] | |||||
| for r, t in self.out_adj[v]: | |||||
| T.append((r, 0, colour[t])) # outgoing | |||||
| for r, h_ in self.in_adj[v]: | |||||
| T.append((r, 1, colour[h_])) # incoming | |||||
| T.sort() | |||||
| key = (colour[v], tuple(T)) | |||||
| if key not in bucket: | |||||
| bucket[key] = len(bucket) + 1 | |||||
| next_colour[v] = bucket[key] | |||||
| colour = next_colour | |||||
| _print_prog(H) | |||||
| if progress: | |||||
| print() | |||||
| # ─── Normalised diversity ──────────────────────────────────────────── | |||||
| avg_entropy = sum(entropies) / len(entropies) | |||||
| c_ratio, c_nwlec = _normalise_scores(entropies, n) | |||||
| # ─── Residual relation entropy ─────────────────────────────────────── | |||||
| h_cond = _conditional_relation_entropy( | |||||
| colours_per_h[cond_h], self.dataset.triples, family_map | |||||
| ) | |||||
| # ─── Composite difficulty - Eq. (1) ────────────────────────────────── | |||||
| d_ratio = c_ratio * h_cond | |||||
| d_nwlec = c_nwlec * h_cond | |||||
| result = (avg_entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec) | |||||
| if return_full: | |||||
| result += (entropies,) | |||||
| return result | |||||
| # -------------------------------------------------------------------- | |||||
| # Weisfeiler–Lehman colours (unchanged) | |||||
| # -------------------------------------------------------------------- | |||||
| def wl_colours(self, depth: int) -> List[List[int]]: | |||||
| triples = self.dataset.triples.tolist() | |||||
| out_adj, in_adj = defaultdict(list), defaultdict(list) | |||||
| for h, r, t in triples: | |||||
| out_adj[h].append((r, t)) | |||||
| in_adj[t].append((r, h)) | |||||
| n_ent = self.dataset.num_entities | |||||
| colours = [[0] * n_ent] # round-0 | |||||
| for h in range(1, depth + 1): | |||||
| prev, nxt, sig2c = colours[-1], [0] * n_ent, {} | |||||
| fresh = 0 | |||||
| for v in range(n_ent): | |||||
| neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [ | |||||
| ("↑", r, prev[u]) for r, u in in_adj.get(v, []) | |||||
| ] | |||||
| neigh.sort() | |||||
| sig = (prev[v], tuple(neigh)) | |||||
| if sig not in sig2c: | |||||
| sig2c[sig] = fresh | |||||
| fresh += 1 | |||||
| nxt[v] = sig2c[sig] | |||||
| colours.append(nxt) | |||||
| return colours |
| from __future__ import annotations | |||||
| import math | |||||
| from collections import Counter | |||||
| from typing import TYPE_CHECKING, Dict, List, Tuple | |||||
| from .base_metric import BaseMetric | |||||
| if TYPE_CHECKING: | |||||
| from data.kg_dataset import KGDataset | |||||
| def _entropy_from_counter(counter: Counter[int], n: int) -> float: | |||||
| """Shannon entropy ``H = -Σ pᵢ log pᵢ`` in *nats* for the given colour counts.""" | |||||
| ent = 0.0 | |||||
| for count in counter.values(): | |||||
| if count: # avoid log(0) | |||||
| p = count / n | |||||
| ent -= p * math.log(p) | |||||
| return ent | |||||
| def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]: | |||||
| """Return (C_ratio, C_NWLEC) given *per-iteration* entropies.""" | |||||
| if n <= 1: | |||||
| # Degenerate graph. | |||||
| return 0.0, 0.0 | |||||
| H_plus_1 = len(entropies) | |||||
| # Eq. (1): simple ratio normalisation. | |||||
| c_ratio = sum(ent / math.log(n) for ent in entropies) / H_plus_1 | |||||
| # Eq. (2): effective-colour NWLEC. | |||||
| k_terms = [(math.exp(ent) - 1) / (n - 1) for ent in entropies] | |||||
| c_nwlec = sum(k_terms) / H_plus_1 | |||||
| return c_ratio, c_nwlec | |||||
| class WLEC(BaseMetric): | |||||
| """Class to compute the Weisfeiler-Lehman Entropy Complexity (WLEC) for a knowledge graph.""" | |||||
| def __init__(self, dataset: KGDataset) -> None: | |||||
| super().__init__(dataset) | |||||
| def compute( | |||||
| self, | |||||
| H: int = 3, | |||||
| *, | |||||
| return_full: bool = False, | |||||
| progress: bool = True, | |||||
| ) -> float | Tuple[float, List[float]]: | |||||
| """Compute the *average* Weisfeiler-Lehman Entropy Complexity (WLEC). | |||||
| Parameters | |||||
| ---------- | |||||
| dataset: | |||||
| Any :class:`KGDataset`-like object exposing ``triples`` (``LongTensor``), | |||||
| ``num_entities`` and ``num_relations``. | |||||
| H: | |||||
| Number of refinement iterations **after** the initial colouring (so the | |||||
| colours are updated **H** times and the entropy is measured **H+1** times). | |||||
| return_full: | |||||
| If *True*, additionally return the list of entropies for each depth. | |||||
| progress: | |||||
| Whether to print a mini progress bar (no external dependency). | |||||
| Returns | |||||
| ------- | |||||
| average_entropy or (average_entropy, entropies) | |||||
| Average of the stored entropies, and optionally the list itself. | |||||
| """ | |||||
| n = int(self.num_entities) | |||||
| if n == 0: | |||||
| msg = "Dataset appears to contain zero entities." | |||||
| raise ValueError(msg) | |||||
| # Colour assignment - we keep it as a *list* of ints for cheap hashing. | |||||
| # Start with colour 1 for every entity (index 0 unused). | |||||
| colour: List[int] = [1] * n | |||||
| entropies: List[float] = [] | |||||
| # Optional poor-man's progress bar. | |||||
| def _print_prog(i: int) -> None: | |||||
| if progress: | |||||
| print(f"\rIteration {i}/{H}", end="", flush=True) | |||||
| for h in range(H + 1): | |||||
| _print_prog(h) | |||||
| # ------------------------------------------------------------------ | |||||
| # Step 1: measure entropy of current colouring ---------------------- | |||||
| # ------------------------------------------------------------------ | |||||
| freq = Counter(colour) | |||||
| ent = _entropy_from_counter(freq, n) | |||||
| entropies.append(ent) | |||||
| if h == H: | |||||
| break # We have reached the requested depth - stop here. | |||||
| # ------------------------------------------------------------------ | |||||
| # Step 2: create refined colours ----------------------------------- | |||||
| # ------------------------------------------------------------------ | |||||
| bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {} | |||||
| next_colour: List[int] = [0] * n | |||||
| for v in range(n): | |||||
| # Build the multiset T containing outgoing and incoming features. | |||||
| T: List[Tuple[int, int, int]] = [] | |||||
| for r, t in self.out_adj[v]: | |||||
| T.append((r, 0, colour[t])) # 0 = outgoing (\u2193) | |||||
| for r, h_ in self.in_adj[v]: | |||||
| T.append((r, 1, colour[h_])) # 1 = incoming (\u2191) | |||||
| T.sort() # canonical ordering | |||||
| key = (colour[v], tuple(T)) | |||||
| if key not in bucket: | |||||
| bucket[key] = len(bucket) + 1 # new colour ID (start at 1) | |||||
| next_colour[v] = bucket[key] | |||||
| colour = next_colour # move to next iteration | |||||
| _print_prog(H) | |||||
| if progress: | |||||
| print() # newline after progress bar | |||||
| average_entropy = sum(entropies) / len(entropies) | |||||
| c_ratio, c_nwlec = _normalise_scores(entropies, n) | |||||
| if return_full: | |||||
| return average_entropy, entropies, c_ratio, c_nwlec | |||||
| return average_entropy, c_ratio, c_nwlec |
| from __future__ import annotations | |||||
| from abc import ABC, abstractmethod | |||||
| from typing import TYPE_CHECKING | |||||
| import torch | |||||
| from torch import nn | |||||
| from torch.nn.functional import log_softmax, softmax | |||||
| if TYPE_CHECKING: | |||||
| from pykeen.typing import FloatTensor, InductiveMode, LongTensor, MappedTriples, Target | |||||
| class ERModel(nn.Module, ABC): | |||||
| """Base class for knowledge graph models.""" | |||||
| def __init__(self, num_entities: int, num_relations: int, dim: int) -> None: | |||||
| super().__init__() | |||||
| self.num_entities = num_entities | |||||
| self.num_relations = num_relations | |||||
| self.dim = dim | |||||
| self.inner_model = None | |||||
| @property | |||||
| def device(self) -> torch.device: | |||||
| """Return the device of the model.""" | |||||
| return next(self.inner_model.parameters()).device | |||||
| @abstractmethod | |||||
| def reset_parameters(self, *args, **kwargs) -> None: | |||||
| """Reset the parameters of the model.""" | |||||
| def predict( | |||||
| self, | |||||
| hrt_batch: MappedTriples, | |||||
| target: Target, | |||||
| full_batch: bool = True, | |||||
| ids: LongTensor | None = None, | |||||
| **kwargs, | |||||
| ) -> FloatTensor: | |||||
| if not self.inner_model: | |||||
| msg = ( | |||||
| "Inner model is not set. Please initialize the inner model before calling predict." | |||||
| ) | |||||
| raise ValueError( | |||||
| msg, | |||||
| ) | |||||
| return self.inner_model.predict( | |||||
| hrt_batch=hrt_batch, | |||||
| target=target, | |||||
| full_batch=full_batch, | |||||
| ids=ids, | |||||
| **kwargs, | |||||
| ) | |||||
| def forward( | |||||
| self, | |||||
| triples: MappedTriples, | |||||
| slice_size: int | None = None, | |||||
| slice_dim: int = 0, | |||||
| *, | |||||
| mode: InductiveMode | None, | |||||
| ) -> FloatTensor: | |||||
| h_indices = triples[:, 0] | |||||
| r_indices = triples[:, 1] | |||||
| t_indices = triples[:, 2] | |||||
| if not self.inner_model: | |||||
| msg = ( | |||||
| "Inner model is not set. Please initialize the inner model before calling forward." | |||||
| ) | |||||
| raise ValueError( | |||||
| msg, | |||||
| ) | |||||
| return self.inner_model.forward( | |||||
| h_indices=h_indices, | |||||
| r_indices=r_indices, | |||||
| t_indices=t_indices, | |||||
| slice_size=slice_size, | |||||
| slice_dim=slice_dim, | |||||
| mode=mode, | |||||
| ) | |||||
| @torch.inference_mode() | |||||
| def tail_distribution( # (0) TAIL-given-(head, relation) → pθ(t | h,r) | |||||
| self, | |||||
| hr_batch: LongTensor, # (B, 2) : (h,r) | |||||
| *, | |||||
| slice_size: int | None = None, | |||||
| mode: InductiveMode | None = None, | |||||
| log: bool = False, | |||||
| ) -> FloatTensor: # (B, |E|) | |||||
| scores = self.inner_model.score_t(hr_batch=hr_batch, slice_size=slice_size, mode=mode) | |||||
| return log_softmax(scores, -1) if log else softmax(scores, -1) | |||||
| # ----------------------------------------------------------------------- # | |||||
| # (1) HEAD-given-(relation, tail) → pθ(h | r,t) # | |||||
| # ----------------------------------------------------------------------- # | |||||
| @torch.inference_mode() | |||||
| def head_distribution( | |||||
| self, | |||||
| rt_batch: LongTensor, # (B, 2) : (r,t) | |||||
| *, | |||||
| slice_size: int | None = None, | |||||
| mode: InductiveMode | None = None, | |||||
| log: bool = False, | |||||
| ) -> FloatTensor: # (B, |E|) | |||||
| scores = self.inner_model.score_h(rt_batch=rt_batch, slice_size=slice_size, mode=mode) | |||||
| return log_softmax(scores, -1) if log else softmax(scores, -1) | |||||
| # ----------------------------------------------------------------------- # | |||||
| # (2) RELATION-given-(head, tail) → pθ(r | h,t) # | |||||
| # ----------------------------------------------------------------------- # | |||||
| @torch.inference_mode() | |||||
| def relation_distribution( | |||||
| self, | |||||
| ht_batch: LongTensor, # (B, 2) : (h,t) | |||||
| *, | |||||
| slice_size: int | None = None, | |||||
| mode: InductiveMode | None = None, | |||||
| log: bool = False, | |||||
| ) -> FloatTensor: # (B, |R|) | |||||
| scores = self.inner_model.score_r(ht_batch=ht_batch, slice_size=slice_size, mode=mode) | |||||
| return log_softmax(scores, -1) if log else softmax(scores, -1) |
| from __future__ import annotations | |||||
| from typing import TYPE_CHECKING | |||||
| import torch | |||||
| from pykeen.models import TransE as PyKEENTransE | |||||
| from torch import nn | |||||
| from models.base_model import ERModel | |||||
| if TYPE_CHECKING: | |||||
| from pykeen.typing import MappedTriples | |||||
| class TransE(ERModel): | |||||
| """ | |||||
| TransE model for knowledge graph embedding. | |||||
| score = ||h + r - t|| | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| num_entities: int, | |||||
| num_relations: int, | |||||
| triples_factory: MappedTriples, | |||||
| dim: int = 200, | |||||
| sf_norm: int = 1, | |||||
| p_norm: bool = True, | |||||
| margin: float | None = None, | |||||
| epsilon: float | None = None, | |||||
| device: torch.device = torch.device("cpu"), | |||||
| ) -> None: | |||||
| super().__init__(num_entities, num_relations, dim) | |||||
| self.dim = dim | |||||
| self.sf_norm = sf_norm | |||||
| self.p_norm = p_norm | |||||
| self.margin = margin | |||||
| self.epsilon = epsilon | |||||
| self.inner_model = PyKEENTransE( | |||||
| triples_factory=triples_factory, | |||||
| embedding_dim=dim, | |||||
| scoring_fct_norm=sf_norm, | |||||
| power_norm=p_norm, | |||||
| random_seed=42, | |||||
| ).to(device) | |||||
| def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None: | |||||
| # Parameter initialization | |||||
| if margin is None or epsilon is None: | |||||
| # If no margin/epsilon are provided, use Xavier uniform initialization | |||||
| nn.init.xavier_uniform_(self.ent_embeddings.weight.data) | |||||
| nn.init.xavier_uniform_(self.rel_embeddings.weight.data) | |||||
| else: | |||||
| # Otherwise, initialize uniformly in [ -(margin+epsilon)/dim , +(margin+epsilon)/dim ] | |||||
| self.embedding_range = nn.Parameter( | |||||
| torch.Tensor([(margin + epsilon) / self.dim]), | |||||
| requires_grad=False, | |||||
| ) | |||||
| nn.init.uniform_( | |||||
| tensor=self.ent_embeddings.weight.data, | |||||
| a=-self.embedding_range.item(), | |||||
| b=self.embedding_range.item(), | |||||
| ) | |||||
| nn.init.uniform_( | |||||
| tensor=self.rel_embeddings.weight.data, | |||||
| a=-self.embedding_range.item(), | |||||
| b=self.embedding_range.item(), | |||||
| ) |
| from __future__ import annotations | |||||
| from typing import TYPE_CHECKING | |||||
| import torch | |||||
| from pykeen.losses import MarginRankingLoss | |||||
| from pykeen.models import TransH as PyKEENTransH | |||||
| from pykeen.regularizers import LpRegularizer | |||||
| from torch import nn | |||||
| from models.base_model import ERModel | |||||
| if TYPE_CHECKING: | |||||
| from pykeen.typing import MappedTriples | |||||
| class TransH(ERModel): | |||||
| """ | |||||
| TransH model for knowledge graph embedding. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| num_entities: int, | |||||
| num_relations: int, | |||||
| triples_factory: MappedTriples, | |||||
| dim: int = 200, | |||||
| p_norm: bool = True, | |||||
| margin: float | None = None, | |||||
| epsilon: float | None = None, | |||||
| device: torch.device = torch.device("cpu"), | |||||
| **kwargs, | |||||
| ) -> None: | |||||
| super().__init__(num_entities, num_relations, dim) | |||||
| self.dim = dim | |||||
| self.p_norm = p_norm | |||||
| self.margin = margin | |||||
| self.epsilon = epsilon | |||||
| self.inner_model = PyKEENTransH( | |||||
| triples_factory=triples_factory, | |||||
| embedding_dim=dim, | |||||
| scoring_fct_norm=1, | |||||
| loss=self.loss, | |||||
| power_norm=p_norm, | |||||
| entity_initializer="xavier_uniform_", # good default | |||||
| relation_initializer="xavier_uniform_", | |||||
| random_seed=42, | |||||
| ).to(device) | |||||
| @property | |||||
| def loss(self) -> torch.Tensor: | |||||
| return MarginRankingLoss(margin=self.margin, reduction="mean") | |||||
| def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None: | |||||
| # Parameter initialization | |||||
| if margin is None or epsilon is None: | |||||
| # If no margin/epsilon are provided, use Xavier uniform initialization | |||||
| nn.init.xavier_uniform_(self.ent_embeddings.weight.data) | |||||
| nn.init.xavier_uniform_(self.rel_embeddings.weight.data) | |||||
| else: | |||||
| # Otherwise, initialize uniformly in [ -(margin+epsilon)/dim , +(margin+epsilon)/dim ] | |||||
| self.embedding_range = nn.Parameter( | |||||
| torch.Tensor([(margin + epsilon) / self.dim]), | |||||
| requires_grad=False, | |||||
| ) | |||||
| nn.init.uniform_( | |||||
| tensor=self.ent_embeddings.weight.data, | |||||
| a=-self.embedding_range.item(), | |||||
| b=self.embedding_range.item(), | |||||
| ) | |||||
| nn.init.uniform_( | |||||
| tensor=self.rel_embeddings.weight.data, | |||||
| a=-self.embedding_range.item(), | |||||
| b=self.embedding_range.item(), | |||||
| ) |
| from __future__ import annotations | |||||
| from typing import TYPE_CHECKING | |||||
| import torch | |||||
| from pykeen.losses import MarginRankingLoss | |||||
| from pykeen.models import TransR as PyKEENTransR | |||||
| from torch import nn | |||||
| from models.base_model import ERModel | |||||
| if TYPE_CHECKING: | |||||
| from pykeen.typing import MappedTriples | |||||
| class TransR(ERModel): | |||||
| """ | |||||
| TransR model for knowledge graph embedding. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| num_entities: int, | |||||
| num_relations: int, | |||||
| triples_factory: MappedTriples, | |||||
| entity_dim: int = 200, | |||||
| relation_dim: int = 30, | |||||
| p_norm: bool = True, | |||||
| p_norm_value: int = 2, | |||||
| margin: float | None = None, | |||||
| epsilon: float | None = None, | |||||
| device: torch.device = torch.device("cpu"), | |||||
| **kwargs, | |||||
| ) -> None: | |||||
| super().__init__(num_entities, num_relations, entity_dim) | |||||
| self.entity_dim = entity_dim | |||||
| self.relation_dim = relation_dim | |||||
| self.p_norm = p_norm | |||||
| self.margin = margin | |||||
| self.epsilon = epsilon | |||||
| self.inner_model = PyKEENTransR( | |||||
| triples_factory=triples_factory, | |||||
| embedding_dim=entity_dim, | |||||
| relation_dim=relation_dim, | |||||
| scoring_fct_norm=1, | |||||
| loss=self.loss, | |||||
| power_norm=p_norm, | |||||
| entity_initializer="xavier_uniform_", | |||||
| relation_initializer="xavier_uniform_", | |||||
| random_seed=42, | |||||
| ).to(device) | |||||
| @property | |||||
| def loss(self) -> torch.Tensor: | |||||
| return MarginRankingLoss(margin=self.margin, reduction="mean") | |||||
| def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None: | |||||
| # Parameter initialization | |||||
| if margin is None or epsilon is None: | |||||
| # If no margin/epsilon are provided, use Xavier uniform initialization | |||||
| nn.init.xavier_uniform_(self.ent_embeddings.weight.data) | |||||
| nn.init.xavier_uniform_(self.rel_embeddings.weight.data) | |||||
| else: | |||||
| # Otherwise, initialize uniformly in [ -(margin+epsilon)/dim , +(margin+epsilon)/dim ] | |||||
| self.embedding_range = nn.Parameter( | |||||
| torch.Tensor([(margin + epsilon) / self.dim]), | |||||
| requires_grad=False, | |||||
| ) | |||||
| nn.init.uniform_( | |||||
| tensor=self.ent_embeddings.weight.data, | |||||
| a=-self.embedding_range.item(), | |||||
| b=self.embedding_range.item(), | |||||
| ) | |||||
| nn.init.uniform_( | |||||
| tensor=self.rel_embeddings.weight.data, | |||||
| a=-self.embedding_range.item(), | |||||
| b=self.embedding_range.item(), | |||||
| ) |
| [tool.ruff] | |||||
| # Exclude a variety of commonly ignored directories. | |||||
| exclude = [ | |||||
| ".bzr", | |||||
| ".direnv", | |||||
| ".eggs", | |||||
| ".git", | |||||
| ".git-rewrite", | |||||
| ".hg", | |||||
| ".ipynb_checkpoints", | |||||
| ".mypy_cache", | |||||
| ".nox", | |||||
| ".pants.d", | |||||
| ".pyenv", | |||||
| ".pytest_cache", | |||||
| ".pytype", | |||||
| ".ruff_cache", | |||||
| ".svn", | |||||
| ".tox", | |||||
| ".venv", | |||||
| ".vscode", | |||||
| "__pypackages__", | |||||
| "_build", | |||||
| "buck-out", | |||||
| "build", | |||||
| "dist", | |||||
| "node_modules", | |||||
| "site-packages", | |||||
| "venv", | |||||
| ] | |||||
| respect-gitignore = true | |||||
| # Black line length is 88 | |||||
| line-length = 100 | |||||
| indent-width = 4 | |||||
| [tool.ruff.lint] | |||||
| # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. | |||||
| # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or | |||||
| # McCabe complexity (`C901`) by default. | |||||
| select = ["ALL"] | |||||
| ignore = ["ANN002", "ANN003", "RET504", "ERA001", "E741", "N812", "D", "PLR0913", "T201", "FBT001", "FBT002", "PLW0602", "PLW0603", "PTH", "C901", "PLR2004", "UP006", "UP035", "G004", "N803", "N806", "ARG002"] | |||||
| [tool.ruff.lint.per-file-ignores] | |||||
| "__init__.py" = ["E", "F", "I", "N"] | |||||
| # Allow fix for all enabled rules (when `--fix`) is provided. | |||||
| fixable = ["ALL"] | |||||
| unfixable = [] | |||||
| # Allow unused variables when underscore-prefixed. | |||||
| #dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" | |||||
| [tool.ruff.format] | |||||
| # Like Black, use double quotes for strings. | |||||
| quote-style = "double" | |||||
| # Like Black, indent with spaces, rather than tabs. | |||||
| indent-style = "space" | |||||
| # Like Black, respect magic trailing commas. | |||||
| skip-magic-trailing-comma = false | |||||
| # Like Black, automatically detect the appropriate line ending. | |||||
| line-ending = "auto" | |||||
| # Enable auto-formatting of code examples in docstrings. Markdown, | |||||
| # reStructuredText code/literal blocks and doctests are all supported. | |||||
| # | |||||
| # This is currently disabled by default, but it is planned for this | |||||
| # to be opt-out in the future. | |||||
| docstring-code-format = false | |||||
| # Set the line length limit used when formatting code snippets in | |||||
| # docstrings. | |||||
| # | |||||
| # This only has an effect when the `docstring-code-format` setting is | |||||
| # enabled. | |||||
| docstring-code-line-length = "dynamic" | |||||
| [tool.pyright] | |||||
| typeCheckingMode = "off" |
| #!/usr/bin/env bash | |||||
| # install.sh – minimal Conda bootstrap with ~/.bashrc check | |||||
| set -e | |||||
| ENV=kg-env # must match your environment.yml | |||||
| YAML=kg_env.yaml # assumed to be in the current dir | |||||
| # ── 1. try to load an existing conda from ~/.bashrc ────────────────────── | |||||
| [ -f "$HOME/.bashrc" ] && source "$HOME/.bashrc" || true | |||||
| # ── 2. ensure conda exists ─────────────────────────────────────────────── | |||||
| if ! command -v conda &>/dev/null; then | |||||
| echo "[install.sh] Conda not found → installing Miniconda." | |||||
| curl -fsSL https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o miniconda.sh | |||||
| bash miniconda.sh -b -p "$HOME/miniconda3" | |||||
| rm miniconda.sh | |||||
| source "$HOME/miniconda3/etc/profile.d/conda.sh" | |||||
| conda init bash >/dev/null # so future shells have it | |||||
| else | |||||
| # conda exists; load its helper for this shell | |||||
| source "$(conda info --base)/etc/profile.d/conda.sh" | |||||
| fi | |||||
| # ── 3. create or update the project env ────────────────────────────────── | |||||
| conda env create -n "$ENV" -f "$YAML" \ | |||||
| || conda env update -n "$ENV" -f "$YAML" --prune | |||||
| echo "✔ Done. Activate with: conda activate $ENV" |
| # environment.yml ── KG-Embeddings project | |||||
| name: kg | |||||
| channels: | |||||
| - pytorch | |||||
| - defaults | |||||
| - nvidia | |||||
| - conda-forge | |||||
| - https://repo.anaconda.com/pkgs/main | |||||
| - https://repo.anaconda.com/pkgs/r | |||||
| dependencies: | |||||
| - _libgcc_mutex=0.1=conda_forge | |||||
| - _openmp_mutex=4.5=2_gnu | |||||
| - aom=3.6.1=h59595ed_0 | |||||
| - blas=1.0=mkl | |||||
| - brotli-python=1.1.0=py311hfdbb021_2 | |||||
| - bzip2=1.0.8=h4bc722e_7 | |||||
| - ca-certificates=2024.12.14=hbcca054_0 | |||||
| - certifi=2024.12.14=pyhd8ed1ab_0 | |||||
| - cffi=1.17.1=py311hf29c0ef_0 | |||||
| - charset-normalizer=3.4.1=pyhd8ed1ab_0 | |||||
| - cpython=3.11.11=py311hd8ed1ab_1 | |||||
| - cuda-cudart=12.4.127=0 | |||||
| - cuda-cupti=12.4.127=0 | |||||
| - cuda-libraries=12.4.1=0 | |||||
| - cuda-nvrtc=12.4.127=0 | |||||
| - cuda-nvtx=12.4.127=0 | |||||
| - cuda-opencl=12.6.77=0 | |||||
| - cuda-runtime=12.4.1=0 | |||||
| - cuda-version=12.6=3 | |||||
| - ffmpeg=4.4.2=gpl_hdf48244_113 | |||||
| - filelock=3.16.1=pyhd8ed1ab_1 | |||||
| - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 | |||||
| - font-ttf-inconsolata=3.000=h77eed37_0 | |||||
| - font-ttf-source-code-pro=2.038=h77eed37_0 | |||||
| - font-ttf-ubuntu=0.83=h77eed37_3 | |||||
| - fontconfig=2.15.0=h7e30c49_1 | |||||
| - fonts-conda-ecosystem=1=0 | |||||
| - fonts-conda-forge=1=0 | |||||
| - freetype=2.12.1=h267a509_2 | |||||
| - gettext=0.22.5=he02047a_3 | |||||
| - gettext-tools=0.22.5=he02047a_3 | |||||
| - giflib=5.2.2=hd590300_0 | |||||
| - gmp=6.3.0=hac33072_2 | |||||
| - gmpy2=2.1.5=py311h0f6cedb_3 | |||||
| - gnutls=3.7.9=hb077bed_0 | |||||
| - h2=4.1.0=pyhd8ed1ab_1 | |||||
| - hpack=4.0.0=pyhd8ed1ab_1 | |||||
| - hyperframe=6.0.1=pyhd8ed1ab_1 | |||||
| - idna=3.10=pyhd8ed1ab_1 | |||||
| - intel-openmp=2022.0.1=h06a4308_3633 | |||||
| - jinja2=3.1.5=pyhd8ed1ab_0 | |||||
| - lame=3.100=h166bdaf_1003 | |||||
| - lcms2=2.16=hb7c19ff_0 | |||||
| - ld_impl_linux-64=2.43=h712a8e2_2 | |||||
| - lerc=4.0.0=h27087fc_0 | |||||
| - libasprintf=0.22.5=he8f35ee_3 | |||||
| - libasprintf-devel=0.22.5=he8f35ee_3 | |||||
| - libblas=3.9.0=16_linux64_mkl | |||||
| - libcblas=3.9.0=16_linux64_mkl | |||||
| - libcublas=12.4.5.8=0 | |||||
| - libcufft=11.2.1.3=0 | |||||
| - libcufile=1.11.1.6=0 | |||||
| - libcurand=10.3.7.77=0 | |||||
| - libcusolver=11.6.1.9=0 | |||||
| - libcusparse=12.3.1.170=0 | |||||
| - libdeflate=1.23=h4ddbbb0_0 | |||||
| - libdrm=2.4.124=hb9d3cd8_0 | |||||
| - libegl=1.7.0=ha4b6fd6_2 | |||||
| - libexpat=2.6.4=h5888daf_0 | |||||
| - libffi=3.4.2=h7f98852_5 | |||||
| - libgcc=14.2.0=h77fa898_1 | |||||
| - libgcc-ng=14.2.0=h69a702a_1 | |||||
| - libgettextpo=0.22.5=he02047a_3 | |||||
| - libgettextpo-devel=0.22.5=he02047a_3 | |||||
| - libgl=1.7.0=ha4b6fd6_2 | |||||
| - libglvnd=1.7.0=ha4b6fd6_2 | |||||
| - libglx=1.7.0=ha4b6fd6_2 | |||||
| - libgomp=14.2.0=h77fa898_1 | |||||
| - libiconv=1.17=hd590300_2 | |||||
| - libidn2=2.3.7=hd590300_0 | |||||
| - libjpeg-turbo=3.0.0=hd590300_1 | |||||
| - liblapack=3.9.0=16_linux64_mkl | |||||
| - liblzma=5.6.3=hb9d3cd8_1 | |||||
| - libnpp=12.2.5.30=0 | |||||
| - libnsl=2.0.1=hd590300_0 | |||||
| - libnvfatbin=12.6.77=0 | |||||
| - libnvjitlink=12.4.127=0 | |||||
| - libnvjpeg=12.3.1.117=0 | |||||
| - libpciaccess=0.18=hd590300_0 | |||||
| - libpng=1.6.45=h943b412_0 | |||||
| - libsqlite=3.47.2=hee588c1_0 | |||||
| - libstdcxx=14.2.0=hc0a3c3a_1 | |||||
| - libstdcxx-ng=14.2.0=h4852527_1 | |||||
| - libtasn1=4.19.0=h166bdaf_0 | |||||
| - libtiff=4.7.0=hd9ff511_3 | |||||
| - libunistring=0.9.10=h7f98852_0 | |||||
| - libuuid=2.38.1=h0b41bf4_0 | |||||
| - libva=2.22.0=h8a09558_1 | |||||
| - libvpx=1.13.1=h59595ed_0 | |||||
| - libwebp=1.5.0=hae8dbeb_0 | |||||
| - libwebp-base=1.5.0=h851e524_0 | |||||
| - libxcb=1.17.0=h8a09558_0 | |||||
| - libxcrypt=4.4.36=hd590300_1 | |||||
| - libxml2=2.13.5=h0d44e9d_1 | |||||
| - libzlib=1.3.1=hb9d3cd8_2 | |||||
| - llvm-openmp=15.0.7=h0cdce71_0 | |||||
| - markupsafe=3.0.2=py311h2dc5d0c_1 | |||||
| - mkl=2022.1.0=hc2b9512_224 | |||||
| - mpc=1.3.1=h24ddda3_1 | |||||
| - mpfr=4.2.1=h90cbb55_3 | |||||
| - mpmath=1.3.0=pyhd8ed1ab_1 | |||||
| - ncurses=6.5=h2d0b736_2 | |||||
| - nettle=3.9.1=h7ab15ed_0 | |||||
| - networkx=3.4.2=pyh267e887_2 | |||||
| - numpy=2.2.1=py311hf916aec_0 | |||||
| - openh264=2.3.1=hcb278e6_2 | |||||
| - openjpeg=2.5.3=h5fbd93e_0 | |||||
| - openssl=3.4.0=h7b32b05_1 | |||||
| - p11-kit=0.24.1=hc5aa10d_0 | |||||
| - pillow=11.1.0=py311h1322bbf_0 | |||||
| - pip=24.3.1=pyh8b19718_2 | |||||
| - pthread-stubs=0.4=hb9d3cd8_1002 | |||||
| - pycparser=2.22=pyh29332c3_1 | |||||
| - pysocks=1.7.1=pyha55dd90_7 | |||||
| - python=3.11.11=h9e4cc4f_1_cpython | |||||
| - python_abi=3.11=5_cp311 | |||||
| - pytorch=2.5.1=py3.11_cuda12.4_cudnn9.1.0_0 | |||||
| - pytorch-cuda=12.4=hc786d27_7 | |||||
| - pytorch-mutex=1.0=cuda | |||||
| - pyyaml=6.0.2=py311h9ecbd09_1 | |||||
| - readline=8.2=h8228510_1 | |||||
| - requests=2.32.3=pyhd8ed1ab_1 | |||||
| - setuptools=75.8.0=pyhff2d567_0 | |||||
| - svt-av1=1.4.1=hcb278e6_0 | |||||
| - sympy=1.13.3=pyh2585a3b_105 | |||||
| - tk=8.6.13=noxft_h4845f30_101 | |||||
| - torchaudio=2.5.1=py311_cu124 | |||||
| - torchtriton=3.1.0=py311 | |||||
| - torchvision=0.20.1=py311_cu124 | |||||
| - typing_extensions=4.12.2=pyha770c72_1 | |||||
| - tzdata=2024b=hc8b5060_0 | |||||
| - urllib3=2.3.0=pyhd8ed1ab_0 | |||||
| - wayland=1.23.1=h3e06ad9_0 | |||||
| - wayland-protocols=1.37=hd8ed1ab_0 | |||||
| - wheel=0.45.1=pyhd8ed1ab_1 | |||||
| - x264=1!164.3095=h166bdaf_2 | |||||
| - x265=3.5=h924138e_3 | |||||
| - xorg-libx11=1.8.10=h4f16b4b_1 | |||||
| - xorg-libxau=1.0.12=hb9d3cd8_0 | |||||
| - xorg-libxdmcp=1.1.5=hb9d3cd8_0 | |||||
| - xorg-libxext=1.3.6=hb9d3cd8_0 | |||||
| - xorg-libxfixes=6.0.1=hb9d3cd8_0 | |||||
| - yaml=0.2.5=h7f98852_2 | |||||
| - zstandard=0.23.0=py311hbc35293_1 | |||||
| - zstd=1.5.6=ha6fb4c9_0 | |||||
| - pip: | |||||
| - torchmetrics>=1.4 # nicer MRR/Hits@K utilities | |||||
| - lovely-tensors>=0.1.0 # tensor utilities | |||||
| - lovely-numpy>=0.1.0 # numpy utilities | |||||
| - einops==0.8.1 | |||||
| - pykeen==1.11.1 | |||||
| - hydra-core==1.3.2 | |||||
| - tensorboard==2.19.0 |
| from .pretty_logger import get_pretty_logger | |||||
| from .sampling import DeterministicRandomSampler | |||||
| from .utils import set_seed | |||||
| from .tb_handler import TensorBoardHandler | |||||
| from .params import CommonParams, TrainingParams | |||||
| from .checkpoint_manager import CheckpointManager |
| from __future__ import annotations | |||||
| import re | |||||
| from pathlib import Path | |||||
| class CheckpointManager: | |||||
| """ | |||||
| A checkpoint manager that handles checkpoint directory creation and path resolution. | |||||
| Features: | |||||
| - Creates sequentially numbered checkpoint directories | |||||
| - Supports two loading modes: by components or by full path | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| root_directory: str | Path, | |||||
| run_name: str, | |||||
| load_only: bool = False, | |||||
| ) -> None: | |||||
| """ | |||||
| Initialize the checkpoint manager. | |||||
| Args: | |||||
| root_directory: Root directory for checkpoints | |||||
| run_name: Name of the run (used as suffix for checkpoint directories) | |||||
| """ | |||||
| self.root_directory = Path(root_directory) | |||||
| self.run_name = run_name | |||||
| self.checkpoint_directory = None | |||||
| if not load_only: | |||||
| self.root_directory.mkdir(parents=True, exist_ok=True) | |||||
| self.checkpoint_directory = self._create_checkpoint_directory() | |||||
| else: | |||||
| self.checkpoint_directory = "" | |||||
| def _find_existing_directories(self) -> list[int]: | |||||
| """ | |||||
| Find all existing directories with pattern xxx_<run_name> where xxx is a 3-digit number. | |||||
| Returns: | |||||
| List of existing sequence numbers | |||||
| """ | |||||
| pattern = re.compile(rf"^(\d{{3}})_{re.escape(self.run_name)}$") | |||||
| existing_numbers = [] | |||||
| if self.root_directory.exists(): | |||||
| for item in self.root_directory.iterdir(): | |||||
| if item.is_dir(): | |||||
| match = pattern.match(item.name) | |||||
| if match: | |||||
| existing_numbers.append(int(match.group(1))) | |||||
| return sorted(existing_numbers) | |||||
| def _create_checkpoint_directory(self) -> Path: | |||||
| """ | |||||
| Create a new checkpoint directory with the next sequential number. | |||||
| Returns: | |||||
| Path to the created checkpoint directory | |||||
| """ | |||||
| existing_numbers = self._find_existing_directories() | |||||
| # Determine the next number | |||||
| next_number = max(existing_numbers) + 1 if existing_numbers else 1 | |||||
| # Create directory name with 3-digit zero-padded number | |||||
| dir_name = f"{next_number:03d}_{self.run_name}" | |||||
| checkpoint_dir = self.root_directory / dir_name | |||||
| self.run_name = dir_name | |||||
| # Create the directory | |||||
| checkpoint_dir.mkdir(parents=True, exist_ok=True) | |||||
| return checkpoint_dir | |||||
| def get_checkpoint_directory(self) -> Path: | |||||
| """ | |||||
| Get the current checkpoint directory. | |||||
| Returns: | |||||
| Path to the checkpoint directory | |||||
| """ | |||||
| return self.checkpoint_directory | |||||
| def get_model_fpath(self, model_path: str) -> Path: | |||||
| """ | |||||
| Get the full path to a model checkpoint file. | |||||
| Args: | |||||
| model_path: Either a tuple (model_id, iteration) or a string path | |||||
| Returns: | |||||
| Full path to the model checkpoint file | |||||
| """ | |||||
| try: | |||||
| model_path = eval(model_path) | |||||
| if isinstance(model_path, tuple): | |||||
| model_id, iteration = model_path | |||||
| checkpoint_directory = f"{self.root_directory}/{model_id:03d}_{self.run_name}" | |||||
| filename = f"model-{iteration}.pt" | |||||
| return Path(f"{checkpoint_directory}/{filename}") | |||||
| except (SyntaxError, NameError): | |||||
| pass | |||||
| if isinstance(model_path, str): | |||||
| return Path(model_path) | |||||
| msg = "model_path must be a tuple (model_id, iteration) or a string path" | |||||
| raise ValueError(msg) | |||||
| def get_model_path_by_args( | |||||
| self, | |||||
| model_id: str | None = None, | |||||
| iteration: int | str | None = None, | |||||
| full_path: str | Path | None = None, | |||||
| ) -> Path: | |||||
| """ | |||||
| Get the path for loading a model checkpoint. | |||||
| Two modes of operation: | |||||
| 1. Component mode: Provide model_id and iteration to construct path | |||||
| 2. Full path mode: Provide full_path directly | |||||
| Args: | |||||
| model_id: Model identifier (used in component mode) | |||||
| iteration: Training iteration (used in component mode) | |||||
| full_path: Full path to checkpoint (used in full path mode) | |||||
| Returns: | |||||
| Path to the checkpoint file | |||||
| Raises: | |||||
| ValueError: If neither component parameters nor full_path are provided, | |||||
| or if both modes are attempted simultaneously | |||||
| """ | |||||
| # Check which mode we're operating in | |||||
| component_mode = model_id is not None or iteration is not None | |||||
| full_path_mode = full_path is not None | |||||
| if component_mode and full_path_mode: | |||||
| msg = ( | |||||
| "Cannot use both component mode (model_id/iteration) " | |||||
| "and full path mode simultaneously" | |||||
| ) | |||||
| raise ValueError(msg) | |||||
| if not component_mode and not full_path_mode: | |||||
| msg = "Must provide either (model_id and iteration) or full_path" | |||||
| raise ValueError(msg) | |||||
| if full_path_mode: | |||||
| # Full path mode: return the path as-is | |||||
| return Path(full_path) | |||||
| # Component mode: construct path from checkpoint directory, model_id, and iteration | |||||
| if model_id is None or iteration is None: | |||||
| msg = "Both model_id and iteration must be provided in component mode" | |||||
| raise ValueError(msg) | |||||
| filename = f"{model_id}_iter_{iteration}.pt" | |||||
| return self.checkpoint_directory / filename | |||||
| def save_checkpoint_info(self, info: dict) -> None: | |||||
| """ | |||||
| Save checkpoint information to a JSON file in the checkpoint directory. | |||||
| Args: | |||||
| info: Dictionary containing checkpoint metadata | |||||
| """ | |||||
| import json | |||||
| info_file = self.checkpoint_directory / "checkpoint_info.json" | |||||
| with open(info_file, "w") as f: | |||||
| json.dump(info, f, indent=2) | |||||
| def load_checkpoint_info(self) -> dict: | |||||
| """ | |||||
| Load checkpoint information from the checkpoint directory. | |||||
| Returns: | |||||
| Dictionary containing checkpoint metadata | |||||
| """ | |||||
| import json | |||||
| info_file = self.checkpoint_directory / "checkpoint_info.json" | |||||
| if info_file.exists(): | |||||
| with open(info_file) as f: | |||||
| return json.load(f) | |||||
| return {} | |||||
| def __str__(self) -> str: | |||||
| return ( | |||||
| f"CheckpointManager(root='{self.root_directory}', " | |||||
| f"run='{self.run_name}', ckpt_dir='{self.checkpoint_directory}')" | |||||
| ) | |||||
| def __repr__(self) -> str: | |||||
| return self.__str__() | |||||
| # Example usage: | |||||
| if __name__ == "__main__": | |||||
| import tempfile | |||||
| # Create checkpoint manager with proper temp directory | |||||
| with tempfile.TemporaryDirectory() as temp_dir: | |||||
| ckpt_manager = CheckpointManager(temp_dir, "my_experiment") | |||||
| print(f"Checkpoint directory: {ckpt_manager.get_checkpoint_directory()}") | |||||
| # Component mode - construct path from components | |||||
| model_path = ckpt_manager.get_model_path_by_args(model_id="transe", iteration=1000) | |||||
| print(f"Model path (component mode): {model_path}") | |||||
| # Full path mode - use existing full path | |||||
| full_path = "/some/other/path/model.pt" | |||||
| model_path = ckpt_manager.get_model_path_by_args(full_path=full_path) | |||||
| print(f"Model path (full path mode): {model_path}") |
| from __future__ import annotations | |||||
| from dataclasses import dataclass | |||||
| @dataclass | |||||
| class CommonParams: | |||||
| """ | |||||
| Common parameters for the application. | |||||
| """ | |||||
| model_name: str | |||||
| project_name: str | |||||
| run_name: str | |||||
| save_dpath: str | |||||
| save_every: int = 1000 | |||||
| load_path: str | None = "" | |||||
| log_dir: str = "./runs" | |||||
| log_every: int = 10 | |||||
| log_console_every: int = 100 | |||||
| evaluate_only: bool = False | |||||
| @dataclass | |||||
| class TrainingParams: | |||||
| """ | |||||
| Parameters for training. | |||||
| """ | |||||
| lr: float = 0.001 | |||||
| min_lr: float = 0.0001 | |||||
| weight_decay: float = 0.0 | |||||
| t0: int = 50 | |||||
| lr_step: int = 10 | |||||
| gamma: float = 1.0 | |||||
| batch_size: int = 128 | |||||
| num_workers: int = 0 | |||||
| seed: int = 42 | |||||
| num_train_steps: int = 100 | |||||
| num_epochs: int = 100 | |||||
| eval_every: int = 5 | |||||
| validation_sample_size: int = 1000 | |||||
| validation_batch_size: int = 128 |
| # pretty_logger.py | |||||
| import json | |||||
| import logging | |||||
| import logging.handlers | |||||
| import sys | |||||
| import threading | |||||
| import time | |||||
| from contextlib import ContextDecorator | |||||
| from functools import wraps | |||||
| # ANSI escape codes for colors | |||||
| RESET = "\x1b[0m" | |||||
| BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = ( | |||||
| "\x1b[30m", | |||||
| "\x1b[31m", | |||||
| "\x1b[32m", | |||||
| "\x1b[33m", | |||||
| "\x1b[34m", | |||||
| "\x1b[35m", | |||||
| "\x1b[36m", | |||||
| "\x1b[37m", | |||||
| ) | |||||
| _LEVEL_COLORS = { | |||||
| logging.DEBUG: CYAN, | |||||
| logging.INFO: GREEN, | |||||
| logging.WARNING: YELLOW, | |||||
| logging.ERROR: RED, | |||||
| logging.CRITICAL: MAGENTA, | |||||
| } | |||||
| def _supports_color() -> bool: | |||||
| """ | |||||
| Detect whether the running terminal supports ANSI colors. | |||||
| """ | |||||
| if hasattr(sys.stdout, "isatty") and sys.stdout.isatty(): | |||||
| platform = sys.platform | |||||
| if platform == "win32": | |||||
| # On Windows, modern terminals may support ANSI; assume yes for Windows 10+ | |||||
| return True | |||||
| return True | |||||
| return False | |||||
| class PrettyFormatter(logging.Formatter): | |||||
| """ | |||||
| A logging.Formatter that outputs colorized logs to the console | |||||
| (if enabled) and supports JSON formatting (if requested). | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| fmt: str = None, | |||||
| datefmt: str = "%Y-%m-%d %H:%M:%S", | |||||
| use_colors: bool = True, | |||||
| json_output: bool = False, | |||||
| ): | |||||
| super().__init__(fmt=fmt, datefmt=datefmt) | |||||
| self.use_colors = use_colors and _supports_color() | |||||
| self.json_output = json_output | |||||
| def format(self, record: logging.LogRecord) -> str: | |||||
| # If JSON output is requested, dump a dict of relevant fields | |||||
| if self.json_output: | |||||
| log_record = { | |||||
| "timestamp": self.formatTime(record, self.datefmt), | |||||
| "name": record.name, | |||||
| "level": record.levelname, | |||||
| "message": record.getMessage(), | |||||
| } | |||||
| # Include extra/contextual fields | |||||
| for k, v in record.__dict__.get("extra_fields", {}).items(): | |||||
| log_record[k] = v | |||||
| if record.exc_info: | |||||
| log_record["exception"] = self.formatException(record.exc_info) | |||||
| return json.dumps(log_record, ensure_ascii=False) | |||||
| # Otherwise, build a colorized/“pretty” line | |||||
| level_color = _LEVEL_COLORS.get(record.levelno, WHITE) if self.use_colors else "" | |||||
| reset = RESET if self.use_colors else "" | |||||
| timestamp = self.formatTime(record, self.datefmt) | |||||
| name = record.name | |||||
| level = record.levelname | |||||
| # Basic format: [timestamp] [LEVEL] [name]: message | |||||
| base = f"[{timestamp}] [{level}] [{name}]: {record.getMessage()}" | |||||
| # If there are any extra fields attached via LoggerAdapter, append them | |||||
| extra_fields = record.__dict__.get("extra_fields", {}) | |||||
| if extra_fields: | |||||
| # key1=val1 key2=val2 … | |||||
| extras_str = " ".join(f"{k}={v}" for k, v in extra_fields.items()) | |||||
| base = f"{base} :: {extras_str}" | |||||
| # If exception info is present, include traceback | |||||
| if record.exc_info: | |||||
| exc_text = self.formatException(record.exc_info) | |||||
| base = f"{base}\n{exc_text}" | |||||
| if self.use_colors: | |||||
| return f"{level_color}{base}{reset}" | |||||
| else: | |||||
| return base | |||||
| class JsonFormatter(PrettyFormatter): | |||||
| """ | |||||
| Shortcut to force JSON formatting (ignores color flags). | |||||
| """ | |||||
| def __init__(self, datefmt: str = "%Y-%m-%d %H:%M:%S"): | |||||
| super().__init__(fmt=None, datefmt=datefmt, use_colors=False, json_output=True) | |||||
| class ContextFilter(logging.Filter): | |||||
| """ | |||||
| Inject extra_fields from a LoggerAdapter into the LogRecord, so the Formatter can find them. | |||||
| """ | |||||
| def filter(self, record: logging.LogRecord) -> bool: | |||||
| extra = getattr(record, "extra", None) | |||||
| if isinstance(extra, dict): | |||||
| record.extra_fields = extra | |||||
| else: | |||||
| record.extra_fields = {} | |||||
| return True | |||||
| class Timer(ContextDecorator): | |||||
| """ | |||||
| Context manager & decorator to measure execution time of a code block or function. | |||||
| Usage as context manager: | |||||
| with Timer(logger, "block name"): | |||||
| # code … | |||||
| Usage as decorator: | |||||
| @Timer(logger, "function foo") | |||||
| def foo(...): | |||||
| ... | |||||
| """ | |||||
| def __init__(self, logger: logging.Logger, name: str = None, level: int = logging.INFO): | |||||
| self.logger = logger | |||||
| self.name = name or "" | |||||
| self.level = level | |||||
| def __enter__(self): | |||||
| self.start_time = time.time() | |||||
| return self | |||||
| def __exit__(self, exc_type, exc, exc_tb): | |||||
| elapsed = time.time() - self.start_time | |||||
| label = f"[Timer:{self.name}]" if self.name else "[Timer]" | |||||
| self.logger.log(self.level, f"{label} elapsed {elapsed:.4f} sec") | |||||
| return False # do not suppress exceptions | |||||
| def log_function(logger: logging.Logger, level: int = logging.DEBUG): | |||||
| """ | |||||
| Decorator factory to log function entry/exit and execution time. | |||||
| Usage: | |||||
| @log_function(logger, level=logging.INFO) | |||||
| def my_func(...): | |||||
| ... | |||||
| """ | |||||
| def decorator(func): | |||||
| @wraps(func) | |||||
| def wrapper(*args, **kwargs): | |||||
| logger.log(level, f"➤ Entering {func.__name__}()") | |||||
| start = time.time() | |||||
| try: | |||||
| result = func(*args, **kwargs) | |||||
| except Exception: | |||||
| logger.exception(f"✖ Exception in {func.__name__}()") | |||||
| raise | |||||
| elapsed = time.time() - start | |||||
| logger.log(level, f"✔ Exiting {func.__name__}(), elapsed {elapsed:.4f} sec") | |||||
| return result | |||||
| return wrapper | |||||
| return decorator | |||||
| class PrettyLogger: | |||||
| """ | |||||
| Encapsulates a logging.Logger with “pretty” formatting, colorized console output, | |||||
| rotating file handler, JSON option, contextual fields, etc. | |||||
| """ | |||||
| _instances = {} | |||||
| _lock = threading.Lock() | |||||
| def __new__(cls, name: str = "root", **kwargs): | |||||
| """ | |||||
| Implements a simple singleton: multiple calls with the same name return the same instance. | |||||
| """ | |||||
| with cls._lock: | |||||
| if name not in cls._instances: | |||||
| instance = super().__new__(cls) | |||||
| cls._instances[name] = instance | |||||
| return cls._instances[name] | |||||
| def __init__( | |||||
| self, | |||||
| name: str = "root", | |||||
| level: int = logging.DEBUG, | |||||
| log_to_console: bool = True, | |||||
| console_level: int = logging.DEBUG, | |||||
| log_to_file: bool = False, | |||||
| log_file: str = "app.log", | |||||
| file_level: int = logging.INFO, | |||||
| max_bytes: int = 10 * 1024 * 1024, # 10 MB | |||||
| backup_count: int = 5, | |||||
| formatter: PrettyFormatter = None, | |||||
| json_output: bool = False, | |||||
| ): | |||||
| """ | |||||
| Initialize (or re‐configure) a PrettyLogger. | |||||
| Args: | |||||
| name: logger name. | |||||
| level: root level for the logger (lowest level that will be processed). | |||||
| log_to_console: whether to attach a console (StreamHandler). | |||||
| console_level: level for console output. | |||||
| log_to_file: whether to attach a rotating file handler. | |||||
| log_file: path to the log file. | |||||
| file_level: level for file output. | |||||
| max_bytes: rotation threshold in bytes. | |||||
| backup_count: number of backup files to keep. | |||||
| formatter: an instance of PrettyFormatter. If None, a default is created. | |||||
| json_output: if True, forces JSON formatting on all handlers. | |||||
| """ | |||||
| # Avoid re‐initializing if already done | |||||
| logger = logging.getLogger(name) | |||||
| if getattr(logger, "_pretty_logger_inited", False): | |||||
| return | |||||
| self.logger = logger | |||||
| self.logger.setLevel(level) | |||||
| self.logger.propagate = False # avoid double‐logging if root logger is configured elsewhere | |||||
| # Create a default formatter if not supplied | |||||
| if formatter is None: | |||||
| fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s" | |||||
| formatter = PrettyFormatter( | |||||
| fmt=fmt, use_colors=not json_output, json_output=json_output | |||||
| ) | |||||
| # Add ContextFilter so extra_fields is always present | |||||
| context_filter = ContextFilter() | |||||
| self.logger.addFilter(context_filter) | |||||
| # Console handler | |||||
| if log_to_console: | |||||
| ch = logging.StreamHandler(sys.stdout) | |||||
| ch.setLevel(console_level) | |||||
| ch.setFormatter(formatter) | |||||
| self.logger.addHandler(ch) | |||||
| # Rotating file handler | |||||
| if log_to_file: | |||||
| fh = logging.handlers.RotatingFileHandler( | |||||
| filename=log_file, | |||||
| maxBytes=max_bytes, | |||||
| backupCount=backup_count, | |||||
| encoding="utf-8", | |||||
| ) | |||||
| fh.setLevel(file_level) | |||||
| # For file, typically we don’t colorize | |||||
| file_formatter = ( | |||||
| JsonFormatter(datefmt="%Y-%m-%d %H:%M:%S") | |||||
| if json_output | |||||
| else PrettyFormatter(fmt=fmt, use_colors=False, json_output=False) | |||||
| ) | |||||
| fh.setFormatter(file_formatter) | |||||
| self.logger.addHandler(fh) | |||||
| # Mark as initialized | |||||
| setattr(self.logger, "_pretty_logger_inited", True) | |||||
| def addHandler(self, handler: logging.Handler): | |||||
| """ | |||||
| Add a custom handler to the logger. | |||||
| This can be used to add any logging.Handler instance. | |||||
| """ | |||||
| self.logger.addHandler(handler) | |||||
| def bind(self, **kwargs): | |||||
| """ | |||||
| Return a LoggerAdapter that automatically injects extra fields into all log records. | |||||
| Example: | |||||
| ctx_logger = pretty_logger.bind(request_id=123, user="alice") | |||||
| ctx_logger.info("Something happened") # will include request_id and user in the output | |||||
| """ | |||||
| return logging.LoggerAdapter(self.logger, {"extra": kwargs}) | |||||
| def add_stream_handler( | |||||
| self, stream=None, level: int = logging.DEBUG, formatter: PrettyFormatter = None | |||||
| ): | |||||
| """ | |||||
| Add an additional stream handler (e.g. to stderr). | |||||
| """ | |||||
| handler = logging.StreamHandler(stream or sys.stdout) | |||||
| handler.setLevel(level) | |||||
| if formatter is None: | |||||
| fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s" | |||||
| formatter = PrettyFormatter(fmt=fmt) | |||||
| handler.setFormatter(formatter) | |||||
| self.logger.addHandler(handler) | |||||
| return handler | |||||
| def add_rotating_file_handler( | |||||
| self, | |||||
| log_file: str, | |||||
| level: int = logging.INFO, | |||||
| max_bytes: int = 10 * 1024 * 1024, | |||||
| backup_count: int = 5, | |||||
| json_output: bool = False, | |||||
| ): | |||||
| """ | |||||
| Add another rotating file handler at runtime. | |||||
| """ | |||||
| fh = logging.handlers.RotatingFileHandler( | |||||
| filename=log_file, | |||||
| maxBytes=max_bytes, | |||||
| backupCount=backup_count, | |||||
| encoding="utf-8", | |||||
| ) | |||||
| fh.setLevel(level) | |||||
| if json_output: | |||||
| formatter = JsonFormatter() | |||||
| else: | |||||
| fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s" | |||||
| formatter = PrettyFormatter(fmt=fmt, use_colors=False, json_output=False) | |||||
| fh.setFormatter(formatter) | |||||
| self.logger.addHandler(fh) | |||||
| return fh | |||||
| def remove_handler(self, handler: logging.Handler): | |||||
| """ | |||||
| Remove a handler from the logger. | |||||
| """ | |||||
| self.logger.removeHandler(handler) | |||||
| def set_level(self, level: int): | |||||
| """ | |||||
| Change the logger’s root level at runtime. | |||||
| """ | |||||
| self.logger.setLevel(level) | |||||
| def get_logger(self) -> logging.Logger: | |||||
| return self.logger | |||||
| # Convenience methods to expose common logging calls directly from this object: | |||||
| def debug(self, msg, *args, **kwargs): | |||||
| self.logger.debug(msg, *args, **kwargs) | |||||
| def info(self, msg, *args, **kwargs): | |||||
| self.logger.info(msg, *args, **kwargs) | |||||
| def warning(self, msg, *args, **kwargs): | |||||
| self.logger.warning(msg, *args, **kwargs) | |||||
| def error(self, msg, *args, **kwargs): | |||||
| self.logger.error(msg, *args, **kwargs) | |||||
| def critical(self, msg, *args, **kwargs): | |||||
| self.logger.critical(msg, *args, **kwargs) | |||||
| def exception(self, msg, *args, exc_info=True, **kwargs): | |||||
| """ | |||||
| Log an error message with exception traceback. By default exc_info=True. | |||||
| """ | |||||
| self.logger.error(msg, *args, exc_info=exc_info, **kwargs) | |||||
| # Convenience function for a module‐level “global” pretty logger: | |||||
| _global_loggers = {} | |||||
| _global_lock = threading.Lock() | |||||
| def get_pretty_logger( | |||||
| name: str = "root", | |||||
| level: int = logging.DEBUG, | |||||
| log_to_console: bool = True, | |||||
| console_level: int = logging.DEBUG, | |||||
| log_to_file: bool = False, | |||||
| log_file: str = "app.log", | |||||
| file_level: int = logging.INFO, | |||||
| max_bytes: int = 10 * 1024 * 1024, | |||||
| backup_count: int = 5, | |||||
| json_output: bool = False, | |||||
| ) -> PrettyLogger: | |||||
| """ | |||||
| Return a PrettyLogger instance for `name`, creating/configuring it on first use. | |||||
| Subsequent calls with the same `name` will return the same logger (singleton‐style). | |||||
| """ | |||||
| with _global_lock: | |||||
| if name not in _global_loggers: | |||||
| pl = PrettyLogger( | |||||
| name=name, | |||||
| level=level, | |||||
| log_to_console=log_to_console, | |||||
| console_level=console_level, | |||||
| log_to_file=log_to_file, | |||||
| log_file=log_file, | |||||
| file_level=file_level, | |||||
| max_bytes=max_bytes, | |||||
| backup_count=backup_count, | |||||
| json_output=json_output, | |||||
| ) | |||||
| _global_loggers[name] = pl | |||||
| return _global_loggers[name] |
| import torch | |||||
| import numpy as np | |||||
| from torch.utils.data.sampler import Sampler | |||||
| def negative_sampling( | |||||
| pos: torch.Tensor, | |||||
| num_entities: int, | |||||
| num_negatives: int = 1, | |||||
| mode: list = ["head", "tail"], | |||||
| ) -> torch.Tensor: | |||||
| pos = pos.repeat_interleave(num_negatives, dim=0) | |||||
| if "head" in mode: | |||||
| neg = torch.randint(0, num_entities, (pos.shape[0],), device=pos.device) | |||||
| pos[:, 0] = neg | |||||
| if "tail" in mode: | |||||
| neg = torch.randint(0, num_entities, (pos.shape[0],), device=pos.device) | |||||
| pos[:, 2] = neg | |||||
| return pos | |||||
| class DeterministicRandomSampler(Sampler): | |||||
| """ | |||||
| A deterministic random sampler that selects a fixed number of samples | |||||
| in a reproducible manner using a given seed. | |||||
| """ | |||||
| def __init__(self, dataset_size: int, sample_size: int = 0, seed: int = 42) -> None: | |||||
| """ | |||||
| Args: | |||||
| dataset (Dataset): PyTorch dataset. | |||||
| sample_size (int, optional): Number of samples to draw. If None, use full dataset. | |||||
| seed (int): Seed for reproducibility. | |||||
| """ | |||||
| self.dataset_size = dataset_size | |||||
| self.seed = seed | |||||
| self.sample_size = sample_size if sample_size != 0 else self.dataset_size | |||||
| # Ensure sample size is within dataset size | |||||
| if self.sample_size > self.dataset_size: | |||||
| msg = f"Sample size {self.sample_size} exceeds dataset size {self.dataset_size}." | |||||
| raise ValueError( | |||||
| msg, | |||||
| ) | |||||
| self.indices = self._generate_deterministic_indices() | |||||
| def _generate_deterministic_indices(self) -> list[int]: | |||||
| """ | |||||
| Generates a fixed random subset of indices using the given seed. | |||||
| """ | |||||
| rng = np.random.default_rng(self.seed) # NumPy's Generator for better reproducibility | |||||
| all_indices = rng.permutation(self.dataset_size) # Shuffle full dataset indices | |||||
| return all_indices[: self.sample_size].tolist() # Select only the desired number of samples | |||||
| def __iter__(self) -> iter: | |||||
| """ | |||||
| Yields the shuffled dataset indices. | |||||
| """ | |||||
| return iter(self.indices) | |||||
| def __len__(self) -> int: | |||||
| """ | |||||
| Returns the total number of samples drawn. | |||||
| """ | |||||
| return self.sample_size |
| from __future__ import annotations | |||||
| import logging | |||||
| from typing import TYPE_CHECKING | |||||
| if TYPE_CHECKING: | |||||
| from torch.utils.tensorboard import SummaryWriter | |||||
| class TensorBoardHandler(logging.Handler): | |||||
| """Convert each log record into a TensorBoard text summary.""" | |||||
| def __init__(self, writer: SummaryWriter, tag: str = "logs"): | |||||
| super().__init__(level=logging.INFO) | |||||
| self.writer = writer | |||||
| self.tag = tag | |||||
| self._step = 0 # will be incremented every emit | |||||
| def emit(self, record: logging.LogRecord) -> None: | |||||
| self.writer.add_text(self.tag, self.format(record), global_step=self._step) | |||||
| self._step += 1 |
| from torch.optim.lr_scheduler import StepLR | |||||
| class StepMinLR(StepLR): | |||||
| def __init__(self, optimizer, step_size, gamma=0.1, min_lr=0.0, last_epoch=-1): | |||||
| self.min_lr = min_lr | |||||
| super().__init__(optimizer, step_size, gamma, last_epoch) | |||||
| def get_lr(self): | |||||
| # use StepLR's formula, then clamp | |||||
| raw_lrs = super().get_lr() | |||||
| return [max(lr, self.min_lr) for lr in raw_lrs] |
| import random | |||||
| from enum import Enum | |||||
| import numpy as np | |||||
| import torch | |||||
| def set_seed(seed: int = 42) -> None: | |||||
| """ | |||||
| Set the random seed for reproducibility. | |||||
| Args: | |||||
| seed (int): The seed value to set for random number generation. | |||||
| """ | |||||
| torch.manual_seed(seed) | |||||
| random.seed(seed) | |||||
| np.random.default_rng(seed) | |||||
| if torch.cuda.is_available(): | |||||
| torch.cuda.manual_seed_all(seed) | |||||
| class Target(str, Enum): # ↔ PyKEEN's LABEL_* constants | |||||
| HEAD = "head" | |||||
| TAIL = "tail" | |||||
| RELATION = "relation" |
| from __future__ import annotations | |||||
| from abc import ABC, abstractmethod | |||||
| from typing import TYPE_CHECKING | |||||
| import torch | |||||
| from torch import nn | |||||
| if TYPE_CHECKING: | |||||
| from torch.optim.lr_scheduler import LRScheduler | |||||
| class ModelTrainerBase(ABC): | |||||
| def __init__( | |||||
| self, | |||||
| model: nn.Module, | |||||
| optimizer: torch.optim.Optimizer, | |||||
| scheduler: LRScheduler | None = None, | |||||
| device: torch.device = "cpu", | |||||
| ) -> None: | |||||
| self.model = model.to(device) | |||||
| self.optimizer = optimizer | |||||
| self.scheduler = scheduler | |||||
| self.device = device | |||||
| def save(self, save_dpath: str, step: int) -> None: | |||||
| """ | |||||
| Save the model and optimizer state to the specified directory. | |||||
| Args: | |||||
| save_dpath (str): The directory path where the model and optimizer state will be saved. | |||||
| """ | |||||
| data = { | |||||
| "model": self.model.state_dict(), | |||||
| "optimizer": self.optimizer.state_dict(), | |||||
| "step": step, | |||||
| } | |||||
| if self.scheduler is not None: | |||||
| data["scheduler"] = self.scheduler.state_dict() | |||||
| torch.save( | |||||
| data, | |||||
| f"{save_dpath}/model-{step}.pt", | |||||
| ) | |||||
| def load(self, load_fpath: str) -> None: | |||||
| """ | |||||
| Load the model and optimizer state from the specified directory. | |||||
| Args: | |||||
| load_dpath (str): The directory path from which the model and | |||||
| optimizer state will be loaded. | |||||
| """ | |||||
| data = torch.load(load_fpath, map_location="cpu") | |||||
| model_state_dict = data["model"] | |||||
| self.model.load_state_dict(model_state_dict) # move everything to the correct device | |||||
| self.model.to(self.device) | |||||
| optimizer_state_dict = data["optimizer"] | |||||
| self.optimizer.load_state_dict(optimizer_state_dict) | |||||
| if self.scheduler is not None: | |||||
| scheduler_data = data.get("scheduler", {}) | |||||
| self.scheduler.load_state_dict(scheduler_data) | |||||
| return data["step"] | |||||
| @abstractmethod | |||||
| def _loss(self, batch: torch.Tensor) -> torch.Tensor: | |||||
| """ | |||||
| Compute the loss for a batch of data. | |||||
| Args: | |||||
| batch (torch.Tensor): A tensor containing a batch of data. | |||||
| Returns: | |||||
| torch.Tensor: The computed loss for the batch. | |||||
| """ | |||||
| ... | |||||
| @abstractmethod | |||||
| def train_step(self, batch: torch.Tensor) -> float: ... | |||||
| @abstractmethod | |||||
| def eval_step(self, batch: torch.Tensor) -> float: ... |
| import torch | |||||
| import torch.nn.functional as F | |||||
| from einops import repeat | |||||
| from pykeen.sampling import BernoulliNegativeSampler | |||||
| from pykeen.training import SLCWATrainingLoop | |||||
| from pykeen.triples.instances import SLCWABatch | |||||
| from torch.optim import Adam | |||||
| from torch.optim.lr_scheduler import StepLR | |||||
| from torch.utils.data import DataLoader | |||||
| from metrics.ranking import simple_ranking | |||||
| from models.translation.trans_e import TransE | |||||
| from tools.params import TrainingParams | |||||
| from training.model_trainers.model_trainer_base import ModelTrainerBase | |||||
| class TransETrainer(ModelTrainerBase): | |||||
| def __init__( | |||||
| self, | |||||
| model: TransE, | |||||
| training_params: TrainingParams, | |||||
| triples_factory: torch.Tensor, | |||||
| mapped_triples: torch.Tensor, | |||||
| device: torch.device, | |||||
| margin: float = 1.0, | |||||
| n_negative: int = 1, | |||||
| loss_fn: str = "margin", | |||||
| **kwargs, | |||||
| ) -> None: | |||||
| optimizer = Adam( | |||||
| model.inner_model.parameters(), | |||||
| lr=training_params.lr, | |||||
| weight_decay=training_params.weight_decay, | |||||
| ) | |||||
| # scheduler = None # Placeholder for scheduler, can be implemented later | |||||
| scheduler = StepLR( | |||||
| optimizer, | |||||
| step_size=training_params.lr_step, | |||||
| gamma=training_params.gamma, | |||||
| ) | |||||
| super().__init__(model, optimizer, scheduler, device) | |||||
| self.triples_factory = triples_factory | |||||
| self.margin = margin | |||||
| self.n_negative = n_negative | |||||
| self.loss_fn = loss_fn | |||||
| self.negative_sampler = BernoulliNegativeSampler( | |||||
| mapped_triples=mapped_triples.to(device), | |||||
| num_entities=model.num_entities, | |||||
| num_relations=model.num_relations, | |||||
| num_negs_per_pos=n_negative, | |||||
| ).to(device) | |||||
| self.training_loop = SLCWATrainingLoop( | |||||
| model=model.inner_model, | |||||
| triples_factory=triples_factory, | |||||
| optimizer=optimizer, | |||||
| negative_sampler=self.negative_sampler, | |||||
| lr_scheduler=scheduler, | |||||
| ) | |||||
| def create_data_loader( | |||||
| self, | |||||
| triples_factory: torch.Tensor, | |||||
| batch_size: int, | |||||
| shuffle: bool = True, | |||||
| ) -> DataLoader: | |||||
| triples_factory = self.triples_factory if triples_factory is None else triples_factory | |||||
| return self.training_loop._create_training_data_loader( | |||||
| triples_factory=triples_factory, | |||||
| sampler=None, | |||||
| batch_size=batch_size, | |||||
| shuffle=shuffle, | |||||
| drop_last=False, | |||||
| ) | |||||
| def _loss(self, pos_scores: torch.Tensor, neg_scores: torch.Tensor) -> torch.Tensor: | |||||
| """ | |||||
| Compute the loss based on the positive and negative scores. | |||||
| Args: | |||||
| pos_scores (torch.Tensor): Scores for positive triples. | |||||
| neg_scores (torch.Tensor): Scores for negative triples. | |||||
| Returns: | |||||
| torch.Tensor: The computed loss. | |||||
| """ | |||||
| k = neg_scores.shape[1] | |||||
| if self.loss_fn == "margin": | |||||
| target = -torch.ones_like(neg_scores) # want s_pos < s_neg | |||||
| loss = F.margin_ranking_loss( | |||||
| repeat(pos_scores, "b -> b k", k=k), | |||||
| neg_scores, | |||||
| target, | |||||
| margin=self.margin, | |||||
| ) | |||||
| return loss | |||||
| if self.loss_fn == "adversarial": | |||||
| # b = pos_scores.size(0) | |||||
| # neg_scores = rearrange(neg_scores, "(b k) -> b k", b=b, k=k) | |||||
| # importance weights (detach so gradients flow only through log-sigmoid terms) | |||||
| w = F.softmax(self.adv_temp * neg_scores, dim=1).detach() | |||||
| pos_term = -F.logsigmoid(self.margin - pos_scores) # (B,) | |||||
| neg_term = -(w * F.logsigmoid(neg_scores - self.margin)).sum(dim=1) # (B,) | |||||
| loss = (pos_term + neg_term).mean() | |||||
| return loss | |||||
| return None | |||||
| def train_step(self, batch: SLCWABatch, step: int) -> float: | |||||
| """ | |||||
| Perform a training step on the given batch of data. | |||||
| Args: | |||||
| batch (torch.Tensor): A tensor containing a batch of triples. | |||||
| Returns: | |||||
| float: The computed loss for the batch. | |||||
| """ | |||||
| continue_training = step > 0 | |||||
| loss = self.training_loop.train( | |||||
| triples_factory=self.triples_factory, | |||||
| num_epochs=step + 1, | |||||
| batch_size=self.training_loop._get_batch_size(batch), | |||||
| continue_training=continue_training, | |||||
| use_tqdm=False, | |||||
| use_tqdm_batch=False, | |||||
| label_smoothing=0.0, # Assuming no label smoothing for simplicity | |||||
| # num_workers=3, | |||||
| )[-1] | |||||
| return loss | |||||
| def eval_step(self, batch: torch.Tensor) -> torch.Tensor: | |||||
| """Return rank tensor for this batch (needed for MRR / Hits@K).""" | |||||
| self.model.eval() | |||||
| with torch.no_grad(): | |||||
| return simple_ranking(self.model, batch.to(self.device)) |
| import torch | |||||
| from pykeen.losses import MarginRankingLoss | |||||
| from pykeen.sampling import BernoulliNegativeSampler | |||||
| from pykeen.training import SLCWATrainingLoop | |||||
| from pykeen.triples.instances import SLCWABatch | |||||
| from torch.optim import Adam | |||||
| from metrics.ranking import simple_ranking | |||||
| from models.translation.trans_h import TransH | |||||
| from tools.params import TrainingParams | |||||
| from tools.train import StepMinLR | |||||
| from training.model_trainers.model_trainer_base import ModelTrainerBase | |||||
| class TransHTrainer(ModelTrainerBase): | |||||
| def __init__( | |||||
| self, | |||||
| model: TransH, | |||||
| training_params: TrainingParams, | |||||
| triples_factory: torch.Tensor, | |||||
| mapped_triples: torch.Tensor, | |||||
| device: torch.device, | |||||
| margin: float = 1.0, | |||||
| n_negative: int = 1, | |||||
| regul_rate: float = 0.0, | |||||
| loss_fn: str = "margin", | |||||
| **kwargs, | |||||
| ) -> None: | |||||
| optimizer = Adam( | |||||
| model.inner_model.parameters(), | |||||
| lr=training_params.lr, | |||||
| weight_decay=training_params.weight_decay, | |||||
| ) | |||||
| scheduler = StepMinLR( | |||||
| optimizer, | |||||
| step_size=training_params.lr_step, | |||||
| gamma=training_params.gamma, | |||||
| min_lr=training_params.min_lr, | |||||
| ) | |||||
| super().__init__(model, optimizer, scheduler, device) | |||||
| self.triples_factory = triples_factory | |||||
| self.margin = margin | |||||
| self.n_negative = n_negative | |||||
| self.regul_rate = regul_rate | |||||
| self.loss_fn = loss_fn | |||||
| self.negative_sampler = BernoulliNegativeSampler( | |||||
| mapped_triples=mapped_triples.to(device), | |||||
| num_entities=model.num_entities, | |||||
| num_relations=model.num_relations, | |||||
| num_negs_per_pos=n_negative, | |||||
| ).to(device) | |||||
| self.training_loop = SLCWATrainingLoop( | |||||
| model=model.inner_model, | |||||
| triples_factory=triples_factory, | |||||
| optimizer=optimizer, | |||||
| negative_sampler=self.negative_sampler, | |||||
| lr_scheduler=scheduler, | |||||
| ) | |||||
| def _loss(self) -> torch.Tensor: | |||||
| return self.model.loss | |||||
| def create_data_loader( | |||||
| self, | |||||
| triples_factory: torch.Tensor, | |||||
| batch_size: int, | |||||
| shuffle: bool = True, | |||||
| ) -> torch.utils.data.DataLoader: | |||||
| triples_factory = self.triples_factory if triples_factory is None else triples_factory | |||||
| return self.training_loop._create_training_data_loader( | |||||
| triples_factory=triples_factory, | |||||
| sampler=None, | |||||
| batch_size=batch_size, | |||||
| shuffle=shuffle, | |||||
| drop_last=False, | |||||
| ) | |||||
| @property | |||||
| def loss(self) -> torch.Tensor: | |||||
| return MarginRankingLoss(margin=self.margin, reduction="mean") | |||||
| def train_step( | |||||
| self, | |||||
| batch: SLCWABatch, | |||||
| step: int, | |||||
| ) -> float: | |||||
| """ | |||||
| Perform a training step on the given batch of data. | |||||
| Args: | |||||
| batch (torch.Tensor): A tensor containing a batch of triples. | |||||
| Returns: | |||||
| float: The computed loss for the batch. | |||||
| """ | |||||
| continue_training = step > 0 | |||||
| loss = self.training_loop.train( | |||||
| triples_factory=self.triples_factory, | |||||
| num_epochs=step + 1, | |||||
| batch_size=self.training_loop._get_batch_size(batch), | |||||
| continue_training=continue_training, | |||||
| use_tqdm=False, | |||||
| use_tqdm_batch=False, | |||||
| label_smoothing=0.0, # Assuming no label smoothing for simplicity | |||||
| )[-1] | |||||
| return loss | |||||
| def eval_step(self, batch: torch.Tensor) -> torch.Tensor: | |||||
| """Return rank tensor for this batch (needed for MRR / Hits@K).""" | |||||
| self.model.eval() | |||||
| with torch.no_grad(): | |||||
| return simple_ranking(self.model, batch.to(self.device)) |
| import torch | |||||
| from pykeen.sampling import BernoulliNegativeSampler | |||||
| from pykeen.training import SLCWATrainingLoop | |||||
| from pykeen.triples.instances import SLCWABatch | |||||
| from torch.optim import Adam | |||||
| from metrics.ranking import simple_ranking | |||||
| from models.translation.trans_r import TransR | |||||
| from tools.params import TrainingParams | |||||
| from tools.train import StepMinLR | |||||
| from training.model_trainers.model_trainer_base import ModelTrainerBase | |||||
| class TransRTrainer(ModelTrainerBase): | |||||
| def __init__( | |||||
| self, | |||||
| model: TransR, | |||||
| training_params: TrainingParams, | |||||
| triples_factory: torch.Tensor, | |||||
| mapped_triples: torch.Tensor, | |||||
| device: torch.device, | |||||
| margin: float = 1.0, | |||||
| n_negative: int = 1, | |||||
| regul_rate: float = 0.0, | |||||
| loss_fn: str = "margin", | |||||
| **kwargs, | |||||
| ) -> None: | |||||
| optimizer = Adam( | |||||
| model.inner_model.parameters(), | |||||
| lr=training_params.lr, | |||||
| weight_decay=training_params.weight_decay, | |||||
| ) | |||||
| scheduler = StepMinLR( | |||||
| optimizer, | |||||
| step_size=training_params.lr_step, | |||||
| gamma=training_params.gamma, | |||||
| min_lr=training_params.min_lr, | |||||
| ) | |||||
| super().__init__(model, optimizer, scheduler, device) | |||||
| self.triples_factory = triples_factory | |||||
| self.margin = margin | |||||
| self.n_negative = n_negative | |||||
| self.regul_rate = regul_rate | |||||
| self.loss_fn = loss_fn | |||||
| self.negative_sampler = BernoulliNegativeSampler( | |||||
| mapped_triples=mapped_triples.to(device), | |||||
| num_entities=model.num_entities, | |||||
| num_relations=model.num_relations, | |||||
| num_negs_per_pos=n_negative, | |||||
| ).to(device) | |||||
| self.training_loop = SLCWATrainingLoop( | |||||
| model=model.inner_model, | |||||
| triples_factory=triples_factory, | |||||
| optimizer=optimizer, | |||||
| negative_sampler=self.negative_sampler, | |||||
| lr_scheduler=scheduler, | |||||
| ) | |||||
| def create_data_loader( | |||||
| self, | |||||
| triples_factory: torch.Tensor, | |||||
| batch_size: int, | |||||
| shuffle: bool = True, | |||||
| ) -> torch.utils.data.DataLoader: | |||||
| triples_factory = self.triples_factory if triples_factory is None else triples_factory | |||||
| return self.training_loop._create_training_data_loader( | |||||
| triples_factory=triples_factory, | |||||
| sampler=None, | |||||
| batch_size=batch_size, | |||||
| shuffle=shuffle, | |||||
| drop_last=False, | |||||
| ) | |||||
| def _loss(self) -> torch.Tensor: | |||||
| return self.model.loss | |||||
| def train_step( | |||||
| self, | |||||
| batch: SLCWABatch, | |||||
| step: int, | |||||
| ) -> float: | |||||
| """ | |||||
| Perform a training step on the given batch of data. | |||||
| Args: | |||||
| batch (torch.Tensor): A tensor containing a batch of triples. | |||||
| Returns: | |||||
| float: The computed loss for the batch. | |||||
| """ | |||||
| continue_training = step > 0 | |||||
| loss = self.training_loop.train( | |||||
| triples_factory=self.triples_factory, | |||||
| num_epochs=step + 1, | |||||
| batch_size=self.training_loop._get_batch_size(batch), | |||||
| continue_training=continue_training, | |||||
| use_tqdm=False, | |||||
| use_tqdm_batch=False, | |||||
| label_smoothing=0.0, # Assuming no label smoothing for simplicity | |||||
| )[-1] | |||||
| return loss | |||||
| def eval_step(self, batch: torch.Tensor) -> torch.Tensor: | |||||
| """Return rank tensor for this batch (needed for MRR / Hits@K).""" | |||||
| self.model.eval() | |||||
| with torch.no_grad(): | |||||
| return simple_ranking(self.model, batch.to(self.device)) |
| from __future__ import annotations | |||||
| from pathlib import Path | |||||
| from typing import TYPE_CHECKING, Any | |||||
| import torch | |||||
| from pykeen.evaluation import RankBasedEvaluator | |||||
| from pykeen.metrics.ranking import HitsAtK, InverseHarmonicMeanRank | |||||
| from torch.utils.data import DataLoader | |||||
| from torch.utils.tensorboard import SummaryWriter | |||||
| from tools import ( | |||||
| CheckpointManager, | |||||
| DeterministicRandomSampler, | |||||
| TensorBoardHandler, | |||||
| get_pretty_logger, | |||||
| set_seed, | |||||
| ) | |||||
| if TYPE_CHECKING: | |||||
| from data.kg_dataset import KGDataset | |||||
| from tools import CommonParams, TrainingParams | |||||
| from .model_trainers.model_trainer_base import ModelTrainerBase | |||||
| logger = get_pretty_logger(__name__) | |||||
| cpu_device = torch.device("cpu") | |||||
| class Trainer: | |||||
| def __init__( | |||||
| self, | |||||
| train_dataset: KGDataset, | |||||
| val_dataset: KGDataset | None, | |||||
| model_trainer: ModelTrainerBase, | |||||
| common_params: CommonParams, | |||||
| training_params: TrainingParams, | |||||
| device: torch.device = cpu_device, | |||||
| ) -> None: | |||||
| set_seed(training_params.seed) | |||||
| self.train_dataset = train_dataset | |||||
| self.val_dataset = val_dataset if val_dataset is not None else train_dataset | |||||
| self.model_trainer = model_trainer | |||||
| self.common_params = common_params | |||||
| self.training_params = training_params | |||||
| self.device = device | |||||
| self.train_loader = self.model_trainer.create_data_loader( | |||||
| triples_factory=self.train_dataset.triples_factory, | |||||
| batch_size=training_params.batch_size, | |||||
| shuffle=True, | |||||
| ) | |||||
| self.train_iterator = self._data_iterator(self.train_loader) | |||||
| seed = 1234 | |||||
| val_sampler = DeterministicRandomSampler( | |||||
| len(self.val_dataset), | |||||
| sample_size=self.training_params.validation_sample_size, | |||||
| seed=seed, | |||||
| ) | |||||
| self.valid_loader = DataLoader( | |||||
| self.val_dataset, | |||||
| batch_size=training_params.validation_batch_size, | |||||
| shuffle=False, | |||||
| num_workers=training_params.num_workers, | |||||
| sampler=val_sampler, | |||||
| ) | |||||
| self.valid_iterator = self._data_iterator(self.valid_loader) | |||||
| self.step = 0 | |||||
| self.log_every = common_params.log_every | |||||
| self.log_console_every = common_params.log_console_every | |||||
| self.save_dpath = common_params.save_dpath | |||||
| self.save_every = common_params.save_every | |||||
| self.checkpoint_manager = CheckpointManager( | |||||
| root_directory=self.save_dpath, | |||||
| run_name=common_params.run_name, | |||||
| load_only=common_params.evaluate_only, | |||||
| ) | |||||
| self.ckpt_dpath = self.checkpoint_manager.checkpoint_directory | |||||
| if not common_params.evaluate_only: | |||||
| logger.info(f"Saving checkpoints to {self.ckpt_dpath}") | |||||
| if self.common_params.load_path: | |||||
| self.load() | |||||
| self.writer = None | |||||
| if not common_params.evaluate_only: | |||||
| run_name = self.checkpoint_manager.run_name | |||||
| self.writer = SummaryWriter(log_dir=f"{common_params.log_dir}/{run_name}") | |||||
| logger.addHandler(TensorBoardHandler(self.writer)) | |||||
| def log(self, tag: str, log_value: Any, console_msg: str | None = None) -> None: | |||||
| """ | |||||
| Log a message to the console and TensorBoard. | |||||
| Args: | |||||
| message (str): The message to log. | |||||
| """ | |||||
| if console_msg is not None: | |||||
| logger.info(f"{console_msg}") | |||||
| if self.writer is not None: | |||||
| self.writer.add_scalar(tag, log_value, self.step) | |||||
| def save(self) -> None: | |||||
| save_dpath = self.ckpt_dpath | |||||
| save_dpath = Path(save_dpath) | |||||
| save_dpath.mkdir(parents=True, exist_ok=True) | |||||
| self.model_trainer.save(save_dpath, self.step) | |||||
| def load(self) -> None: | |||||
| load_path = self.checkpoint_manager.get_model_fpath(self.common_params.load_path) | |||||
| load_fpath = Path(load_path) | |||||
| if load_fpath.exists(): | |||||
| self.step = self.model_trainer.load(load_fpath) | |||||
| logger.info(f"Model loaded from {load_fpath}.") | |||||
| else: | |||||
| msg = f"Load path {load_fpath} does not exist." | |||||
| raise FileNotFoundError(msg) | |||||
| def _data_iterator(self, loader: DataLoader) -> iter: | |||||
| """Yield batches of positive and negative triples.""" | |||||
| while True: | |||||
| yield from loader | |||||
| def train(self) -> None: | |||||
| """Run `n_steps` optimisation steps; return mean loss.""" | |||||
| data_batch = next(self.train_iterator) | |||||
| loss = self.model_trainer.train_step(data_batch, self.step) | |||||
| if self.step % self.log_console_every == 0: | |||||
| logger.info(f"[train] step {self.step:6d} | loss = {loss:.4f}") | |||||
| if self.step % self.log_every == 0: | |||||
| self.log("train/loss", loss) | |||||
| self.step += 1 | |||||
| def evaluate(self) -> dict[str, torch.Tensor]: | |||||
| ks = [1, 3, 10] | |||||
| metrics = [InverseHarmonicMeanRank()] + [HitsAtK(k) for k in ks] | |||||
| evaluator = RankBasedEvaluator( | |||||
| metrics=metrics, | |||||
| filtered=True, | |||||
| batch_size=self.training_params.validation_batch_size, | |||||
| ) | |||||
| result = evaluator.evaluate( | |||||
| model=self.model_trainer.model, | |||||
| mapped_triples=self.val_dataset.triples, | |||||
| additional_filter_triples=[ | |||||
| self.train_dataset.triples, | |||||
| self.val_dataset.triples, | |||||
| ], | |||||
| batch_size=self.training_params.validation_batch_size, | |||||
| ) | |||||
| mrr = result.get_metric("both.realistic.inverse_harmonic_mean_rank") | |||||
| hits_at = {f"hits_at_{k}": result.get_metric(f"both.realistic.hits_at_{k}") for k in ks} | |||||
| metrics = { | |||||
| "mrr": mrr, | |||||
| **hits_at, | |||||
| } | |||||
| for k, v in metrics.items(): | |||||
| self.log(f"valid/{k}", v, console_msg=f"[valid] {k:6s} = {v:.4f}") | |||||
| # ───────────────────────── # | |||||
| # master schedule | |||||
| # ───────────────────────── # | |||||
| def run(self) -> None: | |||||
| """ | |||||
| Alternate training and evaluation until `total_steps` | |||||
| optimisation steps have been executed. | |||||
| """ | |||||
| total_steps = self.training_params.num_train_steps | |||||
| eval_every = self.training_params.eval_every | |||||
| while self.step < total_steps: | |||||
| self.train() | |||||
| if self.step % eval_every == 0: | |||||
| logger.info(f"Evaluating at step {self.step}...") | |||||
| self.evaluate() | |||||
| if self.step % self.save_every == 0: | |||||
| logger.info(f"Saving model at step {self.step}...") | |||||
| self.save() | |||||
| logger.info("Training complete.") | |||||
| def reset(self) -> None: | |||||
| """ | |||||
| Reset the trainer state, including the model trainer and data iterators. | |||||
| """ | |||||
| self.model_trainer.reset() | |||||
| self.train_iterator = self._data_iterator(self.train_loader) | |||||
| self.valid_iterator = self._data_iterator(self.valid_loader) | |||||
| self.step = 0 | |||||
| logger.info("Trainer state has been reset.") |