# KGEvaluation |
from __future__ import annotations | |||||
import argparse | |||||
import torch | |||||
from data import YAGO310Dataset | |||||
# from metrics.crec_modifier import tune_crec_parallel_edits | |||||
from data.kg_dataset import KGDataset | |||||
from metrics.crec_radius_sample import find_sample | |||||
from metrics.wlcrec import WLCREC | |||||
from tools import get_pretty_logger | |||||
logger = get_pretty_logger(__name__)
# -------------------------------------------------------------------- | |||||
# CLI | |||||
# -------------------------------------------------------------------- | |||||
def main(): | |||||
p = argparse.ArgumentParser() | |||||
p.add_argument("--depth", type=int, default=5) | |||||
p.add_argument("--lower", type=float, default=0.25) | |||||
p.add_argument("--upper", type=float, default=0.26) | |||||
p.add_argument("--seed", type=int, default=0) | |||||
p.add_argument("--save", type=str, default="fb15k_crec_bi.pt") | |||||
args = p.parse_args() | |||||
# ds = FB15KDataset() | |||||
ds = YAGO310Dataset() | |||||
triples = ds.triples | |||||
n_ent, n_rel = ds.num_entities, ds.num_relations | |||||
logging.info("YAGO3-10 loaded |V|=%d |R|=%d |T|=%d", n_ent, n_rel, len(triples)) | |||||
# wlcrec = WLCREC(ds) | |||||
# c = wlcrec.compute(5)[-2] # C_NWLEC | |||||
# colours = wlcrec.wl_colours(args.depth)[-1] | |||||
# tuned = tune_crec( | |||||
# wlcrec, n_ent, n_rel, depth=args.depth, lo=args.lower, hi=args.upper, seed=args.seed | |||||
# ) | |||||
# tuned = tune_crec_parallel_edits( | |||||
# triples, | |||||
# n_ent, | |||||
# n_rel, | |||||
# depth=args.depth, | |||||
# lo=args.lower, | |||||
# hi=args.upper, | |||||
# seed=args.seed, | |||||
# ) | |||||
# tuned = fast_tune_crec( | |||||
# triples.tolist(), | |||||
# colours, | |||||
# n_ent, | |||||
# n_rel, | |||||
# depth=args.depth, | |||||
# c=c, | |||||
# lo=args.lower, | |||||
# hi=args.upper, | |||||
# max_iters=1000, | |||||
# max_workers=10, # Adjust number of workers as needed | |||||
# ) | |||||
tuned = find_sample( | |||||
triples.tolist(), | |||||
n_ent, | |||||
n_rel, | |||||
depth=args.depth, | |||||
ratio=0.1, # Adjust ratio as needed | |||||
radius=2, # Adjust radius as needed | |||||
lower=args.lower, | |||||
upper=args.upper, | |||||
max_tries=1000, | |||||
        seed=58,  # NOTE: fixed seed; the --seed CLI flag above is currently unused
) | |||||
tuned_ds = KGDataset( | |||||
triples_factory=None, | |||||
triples=torch.tensor(tuned, dtype=torch.long), | |||||
num_entities=n_ent, | |||||
num_relations=n_rel, | |||||
split="train", | |||||
) | |||||
tuned_wlcrec = WLCREC(tuned_ds) | |||||
entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = tuned_wlcrec.compute(5) | |||||
print( | |||||
f"\nTuned CREC results (H={5})", | |||||
f"\n • avg WLEC : {entropy:.6f} nats", | |||||
f"\n • C_ratio : {c_ratio:.6f}", | |||||
f"\n • C_NWLEC : {c_nwlec:.6f}", | |||||
f"\n • H_cond(R|S_H) : {h_cond:.6f} nats", | |||||
f"\n • D_ratio : {d_ratio:.6f}", | |||||
f"\n • D_NWLEC : {d_nwlec:.6f}", | |||||
) | |||||
torch.save(torch.tensor(tuned, dtype=torch.long), args.save) | |||||
logging.info("Saved tuned triples → %s", args.save) | |||||
if __name__ == "__main__": | |||||
main() |
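# Hedged usage note: the flags below mirror the argparse definitions above; the
# script name used here is only an assumption.
#   python tune_crec_sample.py --depth 5 --lower 0.25 --upper 0.26 --save yago310_crec_bi.pt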
_target_: tools.params.CommonParams | |||||
model_name: "default_model" | |||||
project_name: "my_project" | |||||
run_name: ... | |||||
# Where to write all outputs (checkpoints / tensorboard logs / etc.) | |||||
save_dpath: "./checkpoints" | |||||
save_every: 200 | |||||
# If you want to resume from a checkpoint, set load_path to a string path;
# otherwise leave it null or empty. | |||||
load_path: null | |||||
# Default log directory | |||||
log_dir: "logs" | |||||
log_every: 10 | |||||
log_console_every: 10 | |||||
evaluate_only: false |
defaults: | |||||
- common: common | |||||
- model: model | |||||
- data: data | |||||
- training: training | |||||
hydra: | |||||
run: | |||||
dir: "./outputs" |
_target_: data.wn18.WN18RRDataset
split: train |
train: | |||||
_target_: data.fb15k.FB15KDataset | |||||
split: train | |||||
valid: | |||||
_target_: data.fb15k.FB15KDataset | |||||
split: valid |
train: | |||||
_target_: data.wn18.WN18Dataset | |||||
split: train | |||||
valid: | |||||
_target_: data.wn18.WN18Dataset | |||||
split: valid |
train: | |||||
_target_: data.wn18.WN18RRDataset | |||||
split: train | |||||
valid: | |||||
_target_: data.wn18.WN18RRDataset | |||||
split: valid |
train: | |||||
_target_: data.yago3_10.YAGO310Dataset | |||||
split: train | |||||
valid: | |||||
_target_: data.yago3_10.YAGO310Dataset | |||||
split: valid | |||||
num_entities: 100 | |||||
num_relations: 100 | |||||
dim: 100 |
defaults: | |||||
- model | |||||
- _self_ | |||||
_target_: models.translation.trans_e.TransE | |||||
dim: 200 | |||||
sf_norm: 1 | |||||
p_norm: true |
defaults: | |||||
- model | |||||
- _self_ | |||||
_target_: models.translation.trans_h.TransH | |||||
dim: 200 | |||||
sf_norm: 1 | |||||
p_norm: true | |||||
p_norm_value: 2 | |||||
margin: 1.0 |
defaults: | |||||
- model | |||||
- _self_ | |||||
_target_: models.translation.trans_r.TransR | |||||
entity_dim: 200 | |||||
relation_dim: 30 | |||||
sf_norm: 1 | |||||
p_norm: true | |||||
p_norm_value: 2 | |||||
margin: 1.0 |
_target_: tools.params.TrainingParams | |||||
lr: 0.005 | |||||
min_lr: 0.0001 | |||||
weight_decay: 0.0 | |||||
t0: 50 | |||||
lr_step: 10 | |||||
gamma: 0.95 | |||||
batch_size: 8192 | |||||
num_workers: 3 | |||||
seed: 42 | |||||
num_train_steps: 10000 | |||||
eval_every: 100 | |||||
validation_sample_size: 1000 | |||||
validation_batch_size: 64 | |||||
model_trainer: | |||||
_target_: training.model_trainers.model_trainer_base.ModelTrainerBase |
defaults: | |||||
- training | |||||
- _self_ | |||||
model_trainer: | |||||
_target_: training.model_trainers.translation.trans_e_trainer.TransETrainer | |||||
margin: 1.0 | |||||
n_negative: 5 | |||||
regul_rate: 0.0 | |||||
loss_fn: "margin" |
defaults: | |||||
- training | |||||
- _self_ | |||||
model_trainer: | |||||
_target_: training.model_trainers.translation.trans_r_trainer.TransRTrainer | |||||
margin: 1.0 | |||||
n_negative: 10 | |||||
regul_rate: 0.0 | |||||
loss_fn: "margin" |
defaults: | |||||
- config | |||||
- override common: common | |||||
- override data: fb15k | |||||
- override model: trans_e | |||||
- override training: trans_e_trainer | |||||
- _self_ | |||||
training: | |||||
model_trainer: | |||||
loss_fn: "margin" | |||||
common: | |||||
run_name: "TransE_FB15K" | |||||
evaluate_only: true | |||||
load_path: (2, 1800) |
defaults: | |||||
- config | |||||
- override common: common | |||||
- override data: wn18 | |||||
- override model: trans_e | |||||
- override training: trans_e_trainer | |||||
- _self_ | |||||
training: | |||||
model_trainer: | |||||
loss_fn: "margin" | |||||
common: | |||||
run_name: "TransE_WN18" | |||||
evaluate_only: true | |||||
load_path: (82, 3000) |
defaults: | |||||
- config | |||||
- override common: common | |||||
- override data: wn18rr | |||||
- override model: trans_e | |||||
- override training: trans_e_trainer | |||||
- _self_ | |||||
training: | |||||
model_trainer: | |||||
loss_fn: "margin" | |||||
common: | |||||
run_name: "TransE_WN18rr" | |||||
evaluate_only: true | |||||
load_path: (2, 6000) |
defaults: | |||||
- config | |||||
- override common: common | |||||
- override data: yago3_10 | |||||
- override model: trans_e | |||||
- override training: trans_e_trainer | |||||
- _self_ | |||||
training: | |||||
model_trainer: | |||||
loss_fn: "margin" | |||||
common: | |||||
run_name: "TransE_YAGO310" | |||||
evaluate_only: true | |||||
load_path: (1, 2600) |
defaults: | |||||
- config | |||||
- override common: common | |||||
- override data: wn18 | |||||
- override model: trans_r | |||||
- override training: trans_r_trainer | |||||
- _self_ | |||||
training: | |||||
model_trainer: | |||||
loss_fn: "margin" | |||||
common: | |||||
run_name: "TransR_WN18" | |||||
# evaluate_only: true | |||||
# load_path: (1, 1000) |
from .fb15k import FB15KDataset | |||||
from .wn18 import WN18Dataset, WN18RRDataset | |||||
from .yago3_10 import YAGO310Dataset | |||||
from .hationet import HetionetDataset | |||||
from .open_bio_link import OpenBioLinkDataset | |||||
from .openke_wiki import OpenKEWikiDataset |
from typing import List, Tuple | |||||
import torch | |||||
from pykeen.datasets import FB15k | |||||
from .kg_dataset import KGDataset | |||||
class FB15KDataset(KGDataset): | |||||
"""A wrapper around PyKeen's FB15k dataset producing KGDataset instances.""" | |||||
def __init__(self, split: str = "train") -> None: | |||||
dataset = FB15k() | |||||
# if split == "train": | |||||
if True: | |||||
tf = dataset.training | |||||
elif split in ("valid", "val", "validation"): | |||||
tf = dataset.validation | |||||
elif split in ("test", "testing"): | |||||
tf = dataset.testing | |||||
else: | |||||
msg = f"Unrecognized split '{split}'" | |||||
raise ValueError(msg) | |||||
custom_triples = torch.load( | |||||
"/home/n_kazemi/projects/KGEvaluation/CustomDataset/fb15k/fb15k_crec_bi_1.pt", | |||||
map_location=torch.device("cpu"), | |||||
).to(torch.long) | |||||
# .tolist() | |||||
tf = tf.clone_and_exchange_triples( | |||||
custom_triples, | |||||
) | |||||
triples_list: List[Tuple[int, int, int]] = [ | |||||
(int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||||
] | |||||
super().__init__( | |||||
triples_factory=tf, | |||||
triples=triples_list, | |||||
num_entities=dataset.num_entities, | |||||
num_relations=dataset.num_relations, | |||||
split=split, | |||||
) | |||||
    def __getitem__(self, idx: int) -> torch.Tensor:
        return super().__getitem__(idx).long()
from typing import List, Tuple | |||||
import torch | |||||
from pykeen.datasets import Hetionet | |||||
from .kg_dataset import KGDataset # Importing KGDataset from the same module | |||||
class HetionetDataset(KGDataset): | |||||
"""A wrapper around PyKeen's Hetionet dataset that produces KGDataset instances.""" | |||||
def __init__(self, split: str = "train") -> None: | |||||
""" | |||||
Initialize the WN18RR wrapper. | |||||
:param split: One of "train", "valid", or "test" indicating which split to load. | |||||
""" | |||||
# Load the PyKeen WN18 dataset | |||||
dataset = Hetionet() | |||||
# Select the appropriate TriplesFactory based on the requested split | |||||
if split == "train": | |||||
tf = dataset.training | |||||
elif split in ("valid", "val", "validation"): | |||||
tf = dataset.validation | |||||
elif split in ("test", "testing"): | |||||
tf = dataset.testing | |||||
else: | |||||
msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||||
raise ValueError( | |||||
msg, | |||||
) | |||||
        # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3) with dtype=torch.long
# Convert each row into a Python tuple of ints | |||||
triples_list: List[Tuple[int, int, int]] = [ | |||||
(int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||||
] | |||||
# Get the total number of entities and relations in the entire dataset | |||||
num_entities = dataset.num_entities | |||||
num_relations = dataset.num_relations | |||||
# Initialize the base KGDataset with the selected triples | |||||
super().__init__( | |||||
triples_factory=tf, | |||||
triples=triples_list, | |||||
num_entities=num_entities, | |||||
num_relations=num_relations, | |||||
split=split, | |||||
) | |||||
    def __getitem__(self, idx: int) -> torch.Tensor:
        # The base dataset already returns a LongTensor row, so no conversion is needed
        return super().__getitem__(idx).long()
"""PyTorch Dataset wrapping indexed triples (h, r, t).""" | |||||
from pathlib import Path | |||||
from typing import Dict, List, Optional, Tuple
import torch
from pykeen.triples import TriplesFactory
from torch.utils.data import Dataset | |||||
class KGDataset(Dataset): | |||||
"""PyTorch Dataset wrapping indexed triples (h, r, t).""" | |||||
def __init__( | |||||
self, | |||||
        triples_factory: Optional[TriplesFactory],
triples: List[Tuple[int, int, int]], | |||||
num_entities: int, | |||||
num_relations: int, | |||||
split: str = "train", | |||||
) -> None: | |||||
super().__init__() | |||||
self.triples_factory = triples_factory | |||||
self.triples = torch.tensor(triples, dtype=torch.long) | |||||
self.num_entities = num_entities | |||||
self.num_relations = num_relations | |||||
self.split = split | |||||
def __len__(self) -> int: | |||||
"""Return the number of triples in the dataset.""" | |||||
return len(self.triples) | |||||
    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the (h, r, t) triple at the given index as a LongTensor."""
        return self.triples[idx]
    def entities(self) -> torch.Tensor:
        """Return a tensor of all entity IDs in the dataset."""
        return torch.arange(self.num_entities)
    def relations(self) -> torch.Tensor:
        """Return a tensor of all relation IDs in the dataset."""
        return torch.arange(self.num_relations)
def __repr__(self) -> str: | |||||
return ( | |||||
f"{self.__class__.__name__}(split={self.split!r}, " | |||||
f"num_triples={len(self.triples)}, " | |||||
f"num_entities={self.num_entities}, " | |||||
f"num_relations={self.num_relations})" | |||||
) | |||||
# ---------------------------- Helper functions -------------------------------- | |||||
def _map_to_ids( | |||||
triples: List[Tuple[str, str, str]], | |||||
) -> Tuple[List[Tuple[int, int, int]], Dict[str, int], Dict[str, int]]: | |||||
ent2id: Dict[str, int] = {} | |||||
rel2id: Dict[str, int] = {} | |||||
mapped: List[Tuple[int, int, int]] = [] | |||||
def _id(d: Dict[str, int], k: str) -> int: | |||||
d.setdefault(k, len(d)) | |||||
return d[k] | |||||
for h, r, t in triples: | |||||
mapped.append((_id(ent2id, h), _id(rel2id, r), _id(ent2id, t))) | |||||
return mapped, ent2id, rel2id | |||||
def load_kg_from_files(
    root: str, splits: Tuple[str, ...] = ("train", "valid", "test")
) -> Tuple[Dict[str, KGDataset], Dict[str, int], Dict[str, int]]:
root = Path(root) | |||||
split_raw, all_raw = {}, [] | |||||
for s in splits: | |||||
lines = [tuple(line.split("\t")) for line in (root / f"{s}.txt").read_text().splitlines()] | |||||
split_raw[s] = lines | |||||
all_raw.extend(lines) | |||||
mapped, ent2id, rel2id = _map_to_ids(all_raw) | |||||
datasets, start = {}, 0 | |||||
for s in splits: | |||||
end = start + len(split_raw[s]) | |||||
        datasets[s] = KGDataset(
            triples_factory=None,
            triples=mapped[start:end],
            num_entities=len(ent2id),
            num_relations=len(rel2id),
            split=s,
        )
start = end | |||||
return datasets, ent2id, rel2id |
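# Minimal self-contained sketch (toy in-memory triples, no PyKEEN factory) of how
# KGDataset behaves; it only runs when this module is executed directly.
if __name__ == "__main__":
    toy = KGDataset(
        triples_factory=None,
        triples=[(0, 0, 1), (1, 1, 2), (2, 0, 0)],
        num_entities=3,
        num_relations=2,
        split="train",
    )
    print(toy)  # KGDataset(split='train', num_triples=3, num_entities=3, num_relations=2)
    print(len(toy), toy[0])  # 3 tensor([0, 0, 1])
    print(toy.entities().tolist(), toy.relations().tolist())  # [0, 1, 2] [0, 1]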
from typing import List, Tuple | |||||
import torch | |||||
from pykeen.datasets import OpenBioLink | |||||
from .kg_dataset import KGDataset # Importing KGDataset from the same module | |||||
class OpenBioLinkDataset(KGDataset): | |||||
"""A wrapper around PyKeen's OpenBioLink dataset that produces KGDataset instances.""" | |||||
def __init__(self, split: str = "train") -> None: | |||||
""" | |||||
Initialize the WN18RR wrapper. | |||||
:param split: One of "train", "valid", or "test" indicating which split to load. | |||||
""" | |||||
# Load the PyKeen WN18 dataset | |||||
dataset = OpenBioLink() | |||||
# Select the appropriate TriplesFactory based on the requested split | |||||
if split == "train": | |||||
tf = dataset.training | |||||
elif split in ("valid", "val", "validation"): | |||||
tf = dataset.validation | |||||
elif split in ("test", "testing"): | |||||
tf = dataset.testing | |||||
else: | |||||
msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||||
raise ValueError( | |||||
msg, | |||||
) | |||||
        # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3) with dtype=torch.long
# Convert each row into a Python tuple of ints | |||||
triples_list: List[Tuple[int, int, int]] = [ | |||||
(int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||||
] | |||||
# Get the total number of entities and relations in the entire dataset | |||||
num_entities = dataset.num_entities | |||||
num_relations = dataset.num_relations | |||||
# Initialize the base KGDataset with the selected triples | |||||
super().__init__( | |||||
triples_factory=tf, | |||||
triples=triples_list, | |||||
num_entities=num_entities, | |||||
num_relations=num_relations, | |||||
split=split, | |||||
) | |||||
    def __getitem__(self, idx: int) -> torch.Tensor:
        # The base dataset already returns a LongTensor row, so no conversion is needed
        return super().__getitem__(idx).long()
from typing import List, Tuple | |||||
import torch | |||||
from pykeen.datasets import Wikidata5M | |||||
from .kg_dataset import KGDataset # Importing KGDataset from the same module | |||||
class OpenKEWikiDataset(KGDataset): | |||||
"""A wrapper around PyKeen's OpenKE-Wiki (WikiN3) dataset that produces KGDataset instances.""" | |||||
def __init__(self, split: str = "train") -> None: | |||||
""" | |||||
Initialize the WN18RR wrapper. | |||||
:param split: One of "train", "valid", or "test" indicating which split to load. | |||||
""" | |||||
# Load the PyKeen WN18 dataset | |||||
dataset = Wikidata5M() | |||||
# Select the appropriate TriplesFactory based on the requested split | |||||
if split == "train": | |||||
tf = dataset.training | |||||
elif split in ("valid", "val", "validation"): | |||||
tf = dataset.validation | |||||
elif split in ("test", "testing"): | |||||
tf = dataset.testing | |||||
else: | |||||
msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||||
raise ValueError( | |||||
msg, | |||||
) | |||||
        # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3) with dtype=torch.long
# Convert each row into a Python tuple of ints | |||||
triples_list: List[Tuple[int, int, int]] = [ | |||||
(int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||||
] | |||||
# Get the total number of entities and relations in the entire dataset | |||||
num_entities = dataset.num_entities | |||||
num_relations = dataset.num_relations | |||||
# Initialize the base KGDataset with the selected triples | |||||
super().__init__( | |||||
triples_factory=tf, | |||||
triples=triples_list, | |||||
num_entities=num_entities, | |||||
num_relations=num_relations, | |||||
split=split, | |||||
) | |||||
    def __getitem__(self, idx: int) -> torch.Tensor:
        # The base dataset already returns a LongTensor row, so no conversion is needed
        return super().__getitem__(idx).long()
from typing import List, Tuple | |||||
import torch | |||||
from pykeen.datasets import WN18, WN18RR | |||||
from .kg_dataset import KGDataset # Importing KGDataset from the same module | |||||
class WN18Dataset(KGDataset): | |||||
"""A wrapper around PyKeen's WN18 dataset that produces KGDataset instances.""" | |||||
def __init__(self, split: str = "train") -> None: | |||||
""" | |||||
Initialize the WN18RR wrapper. | |||||
:param split: One of "train", "valid", or "test" indicating which split to load. | |||||
""" | |||||
# Load the PyKeen WN18 dataset | |||||
dataset = WN18() | |||||
# Select the appropriate TriplesFactory based on the requested split | |||||
if split == "train": | |||||
tf = dataset.training | |||||
elif split in ("valid", "val", "validation"): | |||||
tf = dataset.validation | |||||
elif split in ("test", "testing"): | |||||
tf = dataset.testing | |||||
else: | |||||
msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||||
raise ValueError( | |||||
msg, | |||||
) | |||||
        # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3) with dtype=torch.long
# Convert each row into a Python tuple of ints | |||||
triples_list: List[Tuple[int, int, int]] = [ | |||||
(int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||||
] | |||||
# Get the total number of entities and relations in the entire dataset | |||||
num_entities = dataset.num_entities | |||||
num_relations = dataset.num_relations | |||||
# Initialize the base KGDataset with the selected triples | |||||
super().__init__( | |||||
triples_factory=tf, | |||||
triples=triples_list, | |||||
num_entities=num_entities, | |||||
num_relations=num_relations, | |||||
split=split, | |||||
) | |||||
    def __getitem__(self, idx: int) -> torch.Tensor:
        # The base dataset already returns a LongTensor row, so no conversion is needed
        return super().__getitem__(idx).long()
class WN18RRDataset(KGDataset): | |||||
"""A wrapper around PyKeen's WN18RR dataset that produces KGDataset instances.""" | |||||
def __init__(self, split: str = "train") -> None: | |||||
""" | |||||
Initialize the WN18RR wrapper. | |||||
:param split: One of "train", "valid", or "test" indicating which split to load. | |||||
""" | |||||
# Load the PyKeen WN18RR dataset | |||||
dataset = WN18RR() | |||||
# Select the appropriate TriplesFactory based on the requested split | |||||
if split == "train": | |||||
tf = dataset.training | |||||
elif split in ("valid", "val", "validation"): | |||||
tf = dataset.validation | |||||
elif split in ("test", "testing"): | |||||
tf = dataset.testing | |||||
else: | |||||
msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||||
raise ValueError( | |||||
msg, | |||||
) | |||||
        # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3) with dtype=torch.long
# Convert each row into a Python tuple of ints | |||||
triples_list: List[Tuple[int, int, int]] = [ | |||||
(int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||||
] | |||||
# Get the total number of entities and relations in the entire dataset | |||||
num_entities = dataset.num_entities | |||||
num_relations = dataset.num_relations | |||||
# Initialize the base KGDataset with the selected triples | |||||
super().__init__( | |||||
triples_factory=tf, | |||||
triples=triples_list, | |||||
num_entities=num_entities, | |||||
num_relations=num_relations, | |||||
split=split, | |||||
) | |||||
    def __getitem__(self, idx: int) -> torch.Tensor:
        # The base dataset already returns a LongTensor row, so no conversion is needed
        return super().__getitem__(idx).long()
from typing import List, Tuple | |||||
import torch | |||||
from pykeen.datasets import PathDataset | |||||
from .kg_dataset import KGDataset | |||||
class YAGO310Dataset(KGDataset): | |||||
"""Wrapper around PyKeen's YAGO3-10 dataset.""" | |||||
def __init__(self, split: str = "train") -> None: | |||||
# dataset = YAGO310() | |||||
dataset = PathDataset( | |||||
training_path="/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10/train.txt", | |||||
testing_path="/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10/test.txt", | |||||
validation_path="/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10/valid.txt", | |||||
) | |||||
if split == "train": | |||||
# if True: | |||||
tf = dataset.training | |||||
elif split in ("valid", "val", "validation"): | |||||
tf = dataset.validation | |||||
elif split in ("test", "testing"): | |||||
tf = dataset.testing | |||||
else: | |||||
msg = f"Unrecognized split '{split}'" | |||||
raise ValueError(msg) | |||||
custom_triples = torch.load( | |||||
"/home/n_kazemi/projects/KGEvaluation/CustomDataset/yago310/yago310_crec_bi_1.pt", | |||||
map_location=torch.device("cpu"), | |||||
).to(torch.long) | |||||
# .tolist() | |||||
# get a random sample of size 5000 | |||||
# seed = torch.seed() # For reproducibility | |||||
# torch.manual_seed(seed) | |||||
# torch.cuda.manual_seed(seed) | |||||
# if custom_triples.shape[0] > 5000: | |||||
# custom_triples = custom_triples[torch.randperm(custom_triples.shape[0])[:5000]] | |||||
tf = tf.clone_and_exchange_triples( | |||||
custom_triples, | |||||
) | |||||
# triples = tf.mapped_triples | |||||
# # get a random sample of size 5000 | |||||
# if triples.shape[0] > 5000: | |||||
# indices = torch.randperm(triples.shape[0])[:5000] | |||||
# triples = triples[indices] | |||||
# tf = tf.clone_and_exchange_triples(triples) | |||||
triples_list: List[Tuple[int, int, int]] = [ | |||||
(int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||||
] | |||||
super().__init__( | |||||
triples_factory=tf, | |||||
triples=triples_list, | |||||
num_entities=dataset.num_entities, | |||||
num_relations=dataset.num_relations, | |||||
split=split, | |||||
) | |||||
    def __getitem__(self, idx: int) -> torch.Tensor:
        return super().__getitem__(idx).long()
from data import ( | |||||
FB15KDataset, | |||||
WN18Dataset, | |||||
WN18RRDataset, | |||||
YAGO310Dataset, | |||||
) | |||||
from metrics.wlcrec import WLCREC | |||||
def main() -> None: | |||||
datasets = { | |||||
"YAGO3-10": YAGO310Dataset(split="train"), | |||||
# "WN18": WN18Dataset(split="train"), | |||||
# "WN18RR": WN18RRDataset(split="train"), | |||||
# "FB15K": FB15KDataset(split="train"), | |||||
# "Hetionet": HetionetDataset(split="train"), | |||||
# "OpenBioLink": OpenBioLinkDataset(split="train"), | |||||
# "OpenKEWiki": OpenKEWikiDataset(split="train"), | |||||
} | |||||
for name, dataset in datasets.items(): | |||||
wl_crec = WLCREC(dataset) | |||||
entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = wl_crec.compute(H=20) | |||||
print( | |||||
f"\nDataset: {name}", | |||||
f"\nResults (H={20})", | |||||
f"\n • avg WLEC : {entropy:.6f} nats", | |||||
f"\n • C_ratio : {c_ratio:.6f}", | |||||
f"\n • C_NWLEC : {c_nwlec:.6f}", | |||||
f"\n • H_cond(R|S_H) : {h_cond:.6f} nats", | |||||
f"\n • D_ratio : {d_ratio:.6f}", | |||||
f"\n • D_NWLEC : {d_nwlec:.6f}", | |||||
sep="", | |||||
) | |||||
if __name__ == "__main__": | |||||
main() |
from copy import deepcopy | |||||
import hydra | |||||
import lovely_tensors as lt | |||||
import torch | |||||
from hydra.utils import instantiate | |||||
from omegaconf import DictConfig | |||||
from metrics.c_swklf import CSWKLF | |||||
from tools import get_pretty_logger | |||||
# (Import whatever Trainer or runner you have) | |||||
from training.trainer import Trainer | |||||
logger = get_pretty_logger(__name__) | |||||
lt.monkey_patch() | |||||
@hydra.main(config_path="configs", config_name="config") | |||||
def main(cfg: DictConfig) -> None: | |||||
# Detect CUDA | |||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||||
common_params = cfg.common | |||||
training_params = deepcopy(cfg.training) | |||||
    # drop the model_trainer from cfg.training (it is instantiated separately below)
del training_params.model_trainer | |||||
common_params = instantiate(common_params) | |||||
training_params = instantiate(training_params) | |||||
# dataset instantiation | |||||
train_dataset = instantiate(cfg.data.train) | |||||
val_dataset = instantiate(cfg.data.valid) | |||||
model = instantiate( | |||||
cfg.model, | |||||
num_entities=train_dataset.num_entities, | |||||
num_relations=train_dataset.num_relations, | |||||
triples_factory=train_dataset.triples_factory, | |||||
device=device, | |||||
) | |||||
model = model.to(device) | |||||
model_trainer = instantiate( | |||||
cfg.training.model_trainer, | |||||
model=model, | |||||
training_params=training_params, | |||||
triples_factory=train_dataset.triples_factory, | |||||
mapped_triples=train_dataset.triples, | |||||
device=model.device, | |||||
) | |||||
# Instantiate the Trainer with the model_trainer | |||||
trainer = Trainer( | |||||
train_dataset=train_dataset, | |||||
val_dataset=val_dataset, | |||||
model_trainer=model_trainer, | |||||
common_params=common_params, | |||||
training_params=training_params, | |||||
device=model.device, | |||||
) | |||||
if cfg.common.evaluate_only: | |||||
trainer.evaluate() | |||||
c_swklf_metric = CSWKLF( | |||||
dataset=val_dataset, | |||||
model=model, | |||||
) | |||||
c_swklf_value = c_swklf_metric.compute() | |||||
logger.info(f"CSWKLF Metric Value: {c_swklf_value}") | |||||
else: | |||||
trainer.run() | |||||
if __name__ == "__main__": | |||||
main() |
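# Hedged launch examples: the config-group names (common/data/model/training) come
# from configs/config.yaml, but the script name and the individual YAML file names
# are assumptions about how the configs are laid out on disk.
#   python main.py data=fb15k model=trans_e training=trans_e_trainer
#   python main.py --config-name trans_e_fb15k common.evaluate_only=false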
from __future__ import annotations | |||||
from abc import ABC, abstractmethod | |||||
from typing import TYPE_CHECKING, Any, List, Tuple | |||||
if TYPE_CHECKING: | |||||
from data.kg_dataset import KGDataset | |||||
class BaseMetric(ABC): | |||||
"""Base class for metrics that compute Weisfeiler-Lehman Entropy | |||||
Complexity (WLEC) or similar metrics. | |||||
""" | |||||
def __init__(self, dataset: KGDataset) -> None: | |||||
self.dataset = dataset | |||||
self.num_entities = dataset.num_entities | |||||
self.out_adj, self.in_adj = self._build_adjacency_lists() | |||||
def _build_adjacency_lists( | |||||
self, | |||||
) -> Tuple[List[List[Tuple[int, int]]], List[List[Tuple[int, int]]]]: | |||||
"""Create *in* and *out* adjacency lists from mapped triples. | |||||
Parameters | |||||
---------- | |||||
triples: | |||||
Tensor of shape *(m, 3)* containing *(h, r, t)* triples (``dtype=torch.long``). | |||||
num_entities: | |||||
Total number of entities. Needed to allocate adjacency containers. | |||||
Returns | |||||
------- | |||||
out_adj, in_adj | |||||
Each element ``out_adj[v]`` is a list ``[(r, t), ...]``. Each element | |||||
``in_adj[v]`` is a list ``[(r, h), ...]``. | |||||
""" | |||||
triples = self.dataset.triples # (m, 3) | |||||
num_entities = self.num_entities | |||||
out_adj: List[List[Tuple[int, int]]] = [[] for _ in range(num_entities)] | |||||
in_adj: List[List[Tuple[int, int]]] = [[] for _ in range(num_entities)] | |||||
for h, r, t in triples.tolist(): | |||||
out_adj[h].append((r, t)) | |||||
in_adj[t].append((r, h)) | |||||
return out_adj, in_adj | |||||
@abstractmethod | |||||
def compute( | |||||
self, | |||||
*args, | |||||
**kwargs, | |||||
) -> Any: ... |
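# Minimal sketch of a concrete metric (illustrative only): it subclasses BaseMetric
# and reports the mean out-degree from the adjacency lists built in __init__.
# Assumes data.kg_dataset.KGDataset is importable at runtime; runs only when this
# module is executed directly.
if __name__ == "__main__":
    from data.kg_dataset import KGDataset
    class MeanOutDegree(BaseMetric):
        def compute(self) -> float:
            return sum(len(nbrs) for nbrs in self.out_adj) / max(self.num_entities, 1)
    toy = KGDataset(
        triples_factory=None,
        triples=[(0, 0, 1), (1, 1, 2), (0, 1, 2)],
        num_entities=3,
        num_relations=2,
    )
    print(MeanOutDegree(toy).compute())  # 1.0 on this toy graph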
from __future__ import annotations | |||||
import math | |||||
from collections import defaultdict | |||||
from dataclasses import dataclass | |||||
from typing import TYPE_CHECKING, Callable, Iterable | |||||
import torch | |||||
from torch import Tensor | |||||
from .base_metric import BaseMetric | |||||
if TYPE_CHECKING: | |||||
from data.kg_dataset import KGDataset | |||||
from models.base_model import ERModel | |||||
@torch.no_grad() | |||||
def _log_prob_distribution( | |||||
fn: Callable[[Tensor], Tensor], | |||||
query: Tensor, | |||||
batch_size: int, | |||||
) -> Iterable[Tensor]: | |||||
""" | |||||
Yield *log*-probability batches for an arbitrary distribution helper. | |||||
Parameters | |||||
---------- | |||||
fn : | |||||
A bound method of *model* returning **log**-probabilities, e.g. | |||||
``model.tail_distribution(..., log=True)``. | |||||
query : | |||||
LongTensor of shape (N, 2) containing the query IDs that `fn` expects. | |||||
batch_size : | |||||
Mini-batch size. Choose according to available GPU/CPU memory. | |||||
""" | |||||
for i in range(0, len(query), batch_size): | |||||
yield fn(query[i : i + batch_size]) # (B, |Candidates|) | |||||
def _kl_uniform(log_p: Tensor, true_index_sets: list[list[int]]) -> Tensor: | |||||
""" | |||||
KL(q || p) where q is uniform over *true_index_sets*. | |||||
Parameters | |||||
---------- | |||||
log_p : | |||||
Tensor of shape (B, N) — log-probabilities from the model. | |||||
true_index_sets : | |||||
List (length B). Element *b* contains the list of **indices** | |||||
that are correct for row *b*. | |||||
Returns | |||||
------- | |||||
kl : Tensor, shape (B,) — KL divergence per row (natural log). | |||||
""" | |||||
rows: list[Tensor] = [] | |||||
for lp_row, idx in zip(log_p, true_index_sets): | |||||
k = len(idx) | |||||
        rows.append(-math.log(k) - lp_row[idx].mean())  # KL(q||p) = -log k - E_q[log p]
return torch.tensor(rows, device=log_p.device) | |||||
@dataclass(slots=True) | |||||
class _Accum: | |||||
"""Simple accumulator for ∑ value and ∑ weight.""" | |||||
tot_v: float = 0.0 | |||||
tot_w: float = 0.0 | |||||
def update(self, value: float, weight: float) -> None: | |||||
self.tot_v += value * weight | |||||
self.tot_w += weight | |||||
@property | |||||
def mean(self) -> float: | |||||
return self.tot_v / self.tot_w if self.tot_w else float("nan") | |||||
class CSWKLF(BaseMetric): | |||||
""" | |||||
C-SWKLF metric for evaluating knowledge graph embeddings. | |||||
This metric is used to evaluate the quality of knowledge graph embeddings | |||||
by computing the C-SWKLF score. | |||||
""" | |||||
def __init__(self, dataset: KGDataset, model: ERModel) -> None: | |||||
super().__init__(dataset) | |||||
self.model = model | |||||
@torch.no_grad() | |||||
def compute( | |||||
self, | |||||
alpha: float = 1 / 3, | |||||
beta: float = 1 / 3, | |||||
gamma: float = 1 / 3, | |||||
batch_size: int = 1024, | |||||
slice_size: int | None = 2048, | |||||
device: torch.device | str | None = None, | |||||
) -> float: | |||||
""" | |||||
Compute *Comprehensive Structural-Weighted KL-Fitness*. | |||||
        The model (a trained PyKEEN ``ERModel`` exposing the three distribution
        helpers) and the dataset come from ``__init__`` rather than being passed here.
        Parameters
        ----------
alpha, beta, gamma : | |||||
Non-negative weights for the three query types, default 1/3 each. | |||||
Must sum to *1*. | |||||
batch_size : | |||||
Number of queries scored per forward pass. Tune wrt. GPU/CPU RAM. | |||||
slice_size : | |||||
Forwarded to `tail_distribution` / `head_distribution` / | |||||
`relation_distribution`. `None` disables slicing. | |||||
device : | |||||
Where to do the computation. `None` ⇒ first param's device. | |||||
Returns | |||||
------- | |||||
score : float in (0, 1] | |||||
Higher = model closer to empirical distributions across *all* directions. | |||||
""" | |||||
# --------------------------------------------------------------------- # | |||||
# Preparation # | |||||
# --------------------------------------------------------------------- # | |||||
assert abs(alpha + beta + gamma - 1.0) < 1e-9, "α+β+γ must be 1." | |||||
model_device = next(self.model.parameters()).device | |||||
if device is None: | |||||
device = model_device | |||||
triples = self.dataset.triples.to(device) # (|T|, 3) | |||||
heads, rels, tails = triples.t() # (|T|,) | |||||
# Build index structures -------------------------------------------------- | |||||
tails_by_hr: dict[tuple[int, int], list[int]] = defaultdict(list) | |||||
heads_by_rt: dict[tuple[int, int], list[int]] = defaultdict(list) | |||||
rels_by_ht: dict[tuple[int, int], list[int]] = defaultdict(list) | |||||
for h, r, t in zip(heads.tolist(), rels.tolist(), tails.tolist()): | |||||
tails_by_hr[(h, r)].append(t) | |||||
heads_by_rt[(r, t)].append(h) | |||||
rels_by_ht[(h, t)].append(r) | |||||
# Structural entropies & query lists --------------------------------------- | |||||
H_tail: dict[int, _Accum] = defaultdict(_Accum) # per relation r | |||||
H_head: dict[int, _Accum] = defaultdict(_Accum) | |||||
tail_queries: list[tuple[int, int]] = [] # (h,r) | |||||
head_queries: list[tuple[int, int]] = [] # (r,t) | |||||
rel_queries: list[tuple[int, int]] = [] # (h,t) | |||||
# --- tails -------------------------------------------------------------- | |||||
for (h, r), ts in tails_by_hr.items(): | |||||
k = len(ts) | |||||
H_tail[r].update(math.log(k), 1.0) | |||||
tail_queries.append((h, r)) | |||||
# --- heads -------------------------------------------------------------- | |||||
for (r, t), hs in heads_by_rt.items(): | |||||
k = len(hs) | |||||
H_head[r].update(math.log(k), 1.0) | |||||
head_queries.append((r, t)) | |||||
# --- relations ---------------------------------------------------------- | |||||
H_rel_accum = _Accum() | |||||
for (h, t), rs in rels_by_ht.items(): | |||||
k = len(rs) | |||||
H_rel_accum.update(math.log(k), 1.0) | |||||
rel_queries.append((h, t)) | |||||
H_struct_tail = {r: acc.mean for r, acc in H_tail.items()} | |||||
H_struct_head = {r: acc.mean for r, acc in H_head.items()} | |||||
# H_struct_rel = H_rel_accum.mean # scalar | |||||
# --------------------------------------------------------------------- # | |||||
# KL divergences # | |||||
# --------------------------------------------------------------------- # | |||||
def _avg_kl_per_relation( | |||||
queries: list[tuple[int, int]], | |||||
true_idx_map: dict[tuple[int, int], list[int]], | |||||
distribution_fn: Callable[[Tensor], Tensor], | |||||
            num_relations: bool = False,  # if True, the relation is the first element of each query (head queries)
) -> dict[int, float]: | |||||
""" | |||||
Compute *average* KL(q || p) **per relation**. | |||||
Works for the tail & head conditionals. | |||||
""" | |||||
kl_sum: dict[int, float] = defaultdict(float) | |||||
count: dict[int, int] = defaultdict(int) | |||||
query_tensor = torch.tensor(queries, dtype=torch.long, device=device) | |||||
for log_p in _log_prob_distribution(distribution_fn, query_tensor, batch_size): | |||||
# slice_size handled inside distribution_fn | |||||
# log_p : (B, |Candidates|) | |||||
b = log_p.shape[0] | |||||
# Map back to current slice of queries | |||||
slice_queries = queries[:b] | |||||
queries[:] = queries[b:] # consume | |||||
true_sets = [true_idx_map[q] for q in slice_queries] | |||||
kl_row = _kl_uniform(log_p, true_sets) # (B,) | |||||
for (x, r), kl_val in zip(slice_queries, kl_row.tolist()): | |||||
rel_id = r if not num_relations else x | |||||
kl_sum[rel_id] += kl_val | |||||
count[rel_id] += 1 | |||||
return {r: kl_sum[r] / count[r] for r in kl_sum} | |||||
# --- tails -------------------------------------------------------------- | |||||
tail_queries_copy = tail_queries.copy() | |||||
D_tail = _avg_kl_per_relation( | |||||
tail_queries_copy, | |||||
tails_by_hr, | |||||
lambda q, s=slice_size: self.model.tail_distribution(q, log=True, slice_size=s), | |||||
) | |||||
S_tail = {r: math.exp(-d) for r, d in D_tail.items()} | |||||
# --- heads -------------------------------------------------------------- | |||||
head_queries_copy = head_queries.copy() | |||||
D_head = _avg_kl_per_relation( | |||||
head_queries_copy, | |||||
heads_by_rt, | |||||
lambda q, s=slice_size: self.model.head_distribution(q, log=True, slice_size=s), | |||||
num_relations=True, | |||||
) | |||||
S_head = {r: math.exp(-d) for r, d in D_head.items()} | |||||
# --- relations (single global number) ---------------------------------- | |||||
D_rel_sum = 0.0 | |||||
for log_p in _log_prob_distribution( | |||||
lambda q, s=slice_size: self.model.relation_distribution(q, log=True, slice_size=s), | |||||
torch.tensor(rel_queries, dtype=torch.long, device=device), | |||||
batch_size, | |||||
): | |||||
slice_size_batch = log_p.shape[0] | |||||
slice_queries = rel_queries[:slice_size_batch] | |||||
rel_queries[:] = rel_queries[slice_size_batch:] | |||||
true_sets = [rels_by_ht[q] for q in slice_queries] | |||||
D_rel_sum += _kl_uniform(log_p, true_sets).sum().item() | |||||
D_rel = D_rel_sum / H_rel_accum.tot_w | |||||
S_rel = math.exp(-D_rel) | |||||
# --------------------------------------------------------------------- # | |||||
# Weighted aggregation → C-SWKLF # | |||||
# --------------------------------------------------------------------- # | |||||
def _weighted_mean(score: dict[int, float], weight: dict[int, float]) -> float: | |||||
num = sum(weight[r] * score[r] for r in score) | |||||
den = sum(weight[r] for r in score) | |||||
return num / den if den else float("nan") | |||||
sw_tail = _weighted_mean(S_tail, H_struct_tail) | |||||
sw_head = _weighted_mean(S_head, H_struct_head) | |||||
# Relations: numerator & denominator cancel | |||||
sw_rel = S_rel | |||||
return alpha * sw_tail + beta * sw_head + gamma * sw_rel |
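# Tiny self-contained check (runs only when this module is executed directly) of the
# KL helper above: when the model's distribution is uniform over exactly the true
# candidate set, KL(q || p) is ~0, which is what lets the final score reach 1.0 for a
# perfectly calibrated model.
if __name__ == "__main__":
    demo_log_p = torch.log(
        torch.tensor(
            [
                [0.25, 0.25, 0.25, 0.25],  # uniform over all four candidates
                [0.50, 0.50, 1e-9, 1e-9],  # mass only on the two true candidates
            ]
        )
    )
    demo_true_sets = [[0, 1, 2, 3], [0, 1]]
    print(_kl_uniform(demo_log_p, demo_true_sets))  # ≈ tensor([0., 0.])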
# from __future__ import annotations | |||||
# import math | |||||
# import os | |||||
# import random | |||||
# import threading | |||||
# from collections import defaultdict | |||||
# from concurrent.futures import ThreadPoolExecutor, as_completed | |||||
# from typing import Dict, List, Tuple | |||||
# import torch | |||||
# from data.kg_dataset import KGDataset | |||||
# from metrics.wlcrec import WLCREC # Assuming WLCREC is defined in | |||||
# from tools import get_pretty_logger | |||||
# logger = get_pretty_logger(__name__) | |||||
# # -------------------------------------------------------------------- | |||||
# # Edge-editing primitives | |||||
# # -------------------------------------------------------------------- | |||||
# # -- additions -------------------- | |||||
# def add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng): | |||||
# for _ in range(1000): | |||||
# h = rng.randrange(n_ent) | |||||
# t = rng.randrange(n_ent) | |||||
# if colours[h] != colours[t]: | |||||
# r = rng.randrange(n_rel) | |||||
# triples.append((h, r, t)) | |||||
# out_adj[h].append((r, t)) | |||||
# in_adj[t].append((r, h)) | |||||
# return | |||||
# def add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng): | |||||
# σ = rng.choice(colours) | |||||
# τ = rng.choice(colours) | |||||
# h = colours.index(σ) | |||||
# t = colours.index(τ) | |||||
# triples.append((h, 0, t)) | |||||
# out_adj[h].append((0, t)) | |||||
# in_adj[t].append((0, h)) | |||||
# # -- removals -------------------- | |||||
# def rem_div_edge(triples, out_adj, in_adj, colours, depth, rng): | |||||
# cand = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]] | |||||
# if cand: | |||||
# idx = rng.choice(cand) | |||||
# h, r, t = triples.pop(idx) | |||||
# out_adj[h].remove((r, t)) | |||||
# in_adj[t].remove((r, h)) | |||||
# def rem_det_edge(triples, out_adj, in_adj, colours, depth, rng): | |||||
# sig_rel = defaultdict(set) | |||||
# for h, r, t in triples: | |||||
# σ, τ = colours[h], colours[t] | |||||
# sig_rel[(σ, τ)].add(r) | |||||
# cand = [] | |||||
# for i, (h, r, t) in enumerate(triples): | |||||
# if len(sig_rel[(colours[h], colours[t])]) == 1: | |||||
# cand.append(i) | |||||
# if cand: | |||||
# idx = rng.choice(cand) | |||||
# h, r, t = triples.pop(idx) | |||||
# out_adj[h].remove((r, t)) | |||||
# in_adj[t].remove((r, h)) | |||||
# def _search_worker( | |||||
# seed: int, | |||||
# triples_init: List[Tuple[int, int, int]], | |||||
# triples_factory, | |||||
# n_ent: int, | |||||
# n_rel: int, | |||||
# depth: int, | |||||
# lo: float, | |||||
# hi: float, | |||||
# max_iters: int, | |||||
# ) -> List[Tuple[int, int, int]] | None: | |||||
# """ | |||||
# Run the exact same hill‑climb that `tune_crec()` did, | |||||
# but entirely in this process. Return the edited triples | |||||
# once c falls in [lo, hi]; return None if we used up all iterations. | |||||
# """ | |||||
# rng = random.Random(seed) | |||||
# triples = triples_init.copy().tolist() | |||||
# for it in range(max_iters): | |||||
# # WL‑CREC is *only* recomputed every 1000 edits, exactly like before | |||||
# if it % 1000 == 0: | |||||
# dataset = KGDataset( | |||||
# triples_factory, | |||||
# triples=torch.tensor(triples, dtype=torch.long), | |||||
# num_entities=n_ent, | |||||
# num_relations=n_rel, | |||||
# ) | |||||
# wl = WLCREC(dataset) | |||||
# colours = wl.wl_colours(depth) | |||||
# *_, c, _ = wl.compute(H=5) # unchanged API | |||||
# if lo <= c <= hi: # success | |||||
# logger.info("[seed %d] hit %.4f in %d edits", seed, c, it) | |||||
# return triples | |||||
# # ---------- identical edit logic ---------- | |||||
# if c < lo: | |||||
# if rng.random() < 0.5: | |||||
# rem_det_edge(triples, colours, depth, rng) | |||||
# else: | |||||
# add_div_edge(triples, colours, depth, n_ent, n_rel, rng) | |||||
# else: # c > hi | |||||
# if rng.random() < 0.5: | |||||
# rem_div_edge(triples, colours, depth, rng) | |||||
# else: | |||||
# add_det_edge(triples, colours, depth, n_rel, rng) | |||||
# # ------------------------------------------ | |||||
# return None # used up our budget | |||||
# # -------------------------------------------------------------------- | |||||
# # Unified tuner | |||||
# # -------------------------------------------------------------------- | |||||
# def tune_crec( | |||||
# wl_crec: WLCREC, | |||||
# n_ent: int, | |||||
# n_rel: int, | |||||
# depth: int, | |||||
# lo: float, | |||||
# hi: float, | |||||
# max_iters: int = 80_000, | |||||
# seed: int = 42, | |||||
# ): | |||||
# triples = wl_crec.dataset.triples.tolist() | |||||
# rng = random.Random(seed) | |||||
# for it in range(max_iters): | |||||
# print(f"\r[iter {it + 1:5d}] ", end="") | |||||
# if it % 1000 == 0: | |||||
# dataset = KGDataset( | |||||
# wl_crec.dataset.triples_factory, | |||||
# triples=torch.tensor(triples, dtype=torch.long), | |||||
# num_entities=n_ent, | |||||
# num_relations=n_rel, | |||||
# ) | |||||
# tmp_wl_crec = WLCREC(dataset) | |||||
# # colours = wl_colours(triples, n_ent, depth) | |||||
# colours = tmp_wl_crec.wl_colours(depth) | |||||
# _, _, _, _, c, _ = tmp_wl_crec.compute(H=5) | |||||
# if lo <= c <= hi: | |||||
# logging.info("WL-CREC %.4f reached after %d edits (|T|=%d)", c, it, len(triples)) | |||||
# return triples | |||||
# if c < lo: | |||||
# # need ↑ WL-CREC → prefer deletion of deterministic, else add diversifying | |||||
# if rng.random() < 0.5: | |||||
# rem_det_edge(triples, colours, depth, rng) | |||||
# else: | |||||
# add_div_edge(triples, colours, depth, n_ent, n_rel, rng) | |||||
# # need ↓ WL-CREC → prefer deletion of diversifying, else add deterministic | |||||
# elif rng.random() < 0.5: | |||||
# rem_div_edge(triples, colours, depth, rng) | |||||
# else: | |||||
# add_det_edge(triples, colours, depth, n_rel, rng) | |||||
# if (it + 1) % 10000 == 1: | |||||
# logging.info("[iter %d] WL-CREC %.4f |T|=%d", it + 1, c, len(triples)) | |||||
# raise RuntimeError("Exceeded max iterations without hitting target band.") | |||||
# def _edit_batch( | |||||
# worker_id: int, | |||||
# n_edits: int, | |||||
# colours: List[int], | |||||
# crec: float, | |||||
# target_lo: float, | |||||
# target_hi: float, | |||||
# depth: int, | |||||
# n_ent: int, | |||||
# n_rel: int, | |||||
# seed_base: int, | |||||
# ): | |||||
# """ | |||||
# Perform `n_edits` topology modifications *locally* and return | |||||
# (added_triples, removed_triples) lists. | |||||
# """ | |||||
# rng = random.Random(seed_base + worker_id) | |||||
# local_added, local_removed = [], [] | |||||
# # local graph views: only sizes matter for edit selection | |||||
# for _ in range(n_edits): | |||||
# if crec < target_lo: # need ↑ WL-CREC | |||||
# if rng.random() < 0.5: # • remove deterministic | |||||
# # We cannot remove by index safely without the global list; | |||||
# # choose a *signature* and remember intention to delete: | |||||
# local_removed.append(("det", rng.random())) | |||||
# else: # • add diversifying | |||||
# h = rng.randrange(n_ent) | |||||
# t = rng.randrange(n_ent) | |||||
# while colours[h] == colours[t]: | |||||
# h = rng.randrange(n_ent) | |||||
# t = rng.randrange(n_ent) | |||||
# r = rng.randrange(n_rel) | |||||
# local_added.append((h, r, t)) | |||||
# else: # need ↓ WL-CREC | |||||
# if rng.random() < 0.5: # • remove diversifying | |||||
# local_removed.append(("div", rng.random())) | |||||
# else: # • add deterministic | |||||
# σ = rng.choice(colours) | |||||
# τ = rng.choice(colours) | |||||
# h = colours.index(σ) | |||||
# t = colours.index(τ) | |||||
# local_added.append((h, 0, t)) | |||||
# return local_added, local_removed | |||||
# def wl_colours(triples, n_ent, depth): | |||||
# # build adjacency once | |||||
# out_adj, in_adj = defaultdict(list), defaultdict(list) | |||||
# for h, r, t in triples: | |||||
# out_adj[h].append((r, t)) | |||||
# in_adj[t].append((r, h)) | |||||
# colours_rounds = [[0] * n_ent] # round-0 colours | |||||
# for h in range(1, depth + 1): | |||||
# prev = colours_rounds[-1] | |||||
# # 1) build textual signatures in parallel | |||||
# def sig(v): | |||||
# neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [ | |||||
# ("↑", r, prev[u]) for r, u in in_adj.get(v, []) | |||||
# ] | |||||
# neigh.sort() | |||||
# return (prev[v], tuple(neigh)) | |||||
# with ThreadPoolExecutor() as tpe: # cheap threads inside worker | |||||
# sigs = list(tpe.map(sig, range(n_ent))) | |||||
# # 2) assign deterministic colour IDs | |||||
# sig2id: Dict[Tuple, int] = {} | |||||
# next_round = [0] * n_ent | |||||
# fresh = 0 | |||||
# for v, sg in enumerate(sigs): | |||||
# cid = sig2id.setdefault(sg, fresh) | |||||
# if cid == fresh: | |||||
# fresh += 1 | |||||
# next_round[v] = cid | |||||
# colours_rounds.append(next_round) | |||||
# depth_colours = colours_rounds[-1] | |||||
# return depth_colours | |||||
# def _metric_worker(args): | |||||
# triples, n_ent, n_rel, depth = args | |||||
# dataset = KGDataset( | |||||
# triples_factory=None, # Not used in this context | |||||
# triples=torch.tensor(triples, dtype=torch.long), | |||||
# num_entities=n_ent, | |||||
# num_relations=n_rel, | |||||
# ) | |||||
# wl_crec = WLCREC(dataset) | |||||
# _, _, _, _, c, _ = wl_crec.compute(H=5, return_full=False) | |||||
# colours = wl_colours(triples, n_ent, depth) | |||||
# return c, colours | |||||
# def tune_crec_parallel_edits( | |||||
# triples_init, | |||||
# n_ent: int, | |||||
# n_rel: int, | |||||
# depth: int, | |||||
# target_lo: float, | |||||
# target_hi: float, | |||||
# max_iters: int = 80_000, | |||||
# metric_every: int = 100, | |||||
# n_workers: int = max(20, math.ceil(os.cpu_count() / 2)), | |||||
# seed: int = 42, | |||||
# ): | |||||
# # -------- shared mutable state (main thread owns it) -------- | |||||
# triples = triples_init.tolist() | |||||
# out_adj, in_adj = defaultdict(list), defaultdict(list) | |||||
# for h, r, t in triples: | |||||
# out_adj[h].append((r, t)) | |||||
# in_adj[t].append((r, h)) | |||||
# pool = ThreadPoolExecutor(max_workers=n_workers) | |||||
# metric_lock = threading.Lock() # exactly one metric at a time | |||||
# rng_global = random.Random(seed) | |||||
# # ----- first metric checkpoint ----- | |||||
# crec, colours = _metric_worker((triples, n_ent, n_rel, depth)) | |||||
# edit_budget_total = 0 | |||||
# for it in range(0, max_iters, metric_every): | |||||
# # ========================================================= | |||||
# # 1. PARALLEL EDIT STAGE (metric_every edits in total) | |||||
# # ========================================================= | |||||
# futures = [] | |||||
# edits_per_worker = metric_every // n_workers | |||||
# extra = metric_every % n_workers | |||||
# for wid in range(n_workers): | |||||
# n_edits = edits_per_worker + (1 if wid < extra else 0) | |||||
# futures.append( | |||||
# pool.submit( | |||||
# _edit_batch, | |||||
# wid, | |||||
# n_edits, | |||||
# colours, | |||||
# crec, | |||||
# target_lo, | |||||
# target_hi, | |||||
# depth, | |||||
# n_ent, | |||||
# n_rel, | |||||
# seed, | |||||
# ) | |||||
# ) | |||||
# # merge when workers finish | |||||
# for fut in as_completed(futures): | |||||
# added, removed_specs = fut.result() | |||||
# # --- apply additions immediately (cheap, conflict-free) | |||||
# for h, r, t in added: | |||||
# triples.append((h, r, t)) | |||||
# out_adj[h].append((r, t)) | |||||
# in_adj[t].append((r, h)) | |||||
# # --- apply removals: interpret the spec on *current* graph | |||||
# for typ, randv in removed_specs: | |||||
# if typ == "div": | |||||
# idxs = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]] | |||||
# else: # 'det' | |||||
# sig_rel = defaultdict(set) | |||||
# for h, r, t in triples: | |||||
# sig_rel[(colours[h], colours[t])].add(r) | |||||
# idxs = [ | |||||
# i | |||||
# for i, (h, r, t) in enumerate(triples) | |||||
# if len(sig_rel[(colours[h], colours[t])]) == 1 | |||||
# ] | |||||
# if idxs: | |||||
# victim = idxs[int(randv * len(idxs))] | |||||
# h, r, t = triples.pop(victim) | |||||
# out_adj[h].remove((r, t)) | |||||
# in_adj[t].remove((r, h)) | |||||
# edit_budget_total += metric_every | |||||
# # ========================================================= | |||||
# # 2. SINGLE-THREADED METRIC CHECKPOINT (synchronised) | |||||
# # ========================================================= | |||||
# # if edit_budget_total % (10 * metric_every) == 0: | |||||
# if True: | |||||
# with metric_lock: | |||||
# crec, colours = _metric_worker((triples, n_ent, n_rel, depth)) | |||||
# logging.info( | |||||
# "After %d edits WL-CREC = %.4f |T|=%d", edit_budget_total, crec, len(triples) | |||||
# ) | |||||
# if target_lo <= crec <= target_hi: | |||||
# logging.info("Target band reached.") | |||||
# pool.shutdown(wait=True) | |||||
# return triples | |||||
# pool.shutdown(wait=True) | |||||
# raise RuntimeError("Exceeded max iteration budget without success.") | |||||
# from __future__ import annotations | |||||
# import logging | |||||
# import random | |||||
# from collections import defaultdict | |||||
# from concurrent.futures import ThreadPoolExecutor, as_completed | |||||
# from typing import List, Tuple | |||||
# import torch | |||||
# from data.kg_dataset import KGDataset | |||||
# from metrics.wlcrec import WLCREC | |||||
# from tools import get_pretty_logger | |||||
# logger = get_pretty_logger(__name__) | |||||
# # ------------ additions ------------ | |||||
# def propose_add_div(triples: List, colours, n_ent: int, n_rel: int, rng: random.Random): | |||||
# for _ in range(1000): | |||||
# h, t = rng.randrange(n_ent), rng.randrange(n_ent) | |||||
# if colours[h] != colours[t]: | |||||
# r = rng.randrange(n_rel) | |||||
# return ("add", (h, r, t)) | |||||
# return None # fell through – extremely rare | |||||
# def propose_add_det(colours, n_rel: int, rng: random.Random): | |||||
# σ, τ = rng.choice(colours), rng.choice(colours) | |||||
# h, t = colours.index(σ), colours.index(τ) | |||||
# return ("add", (h, 0, t)) # rel 0 = deterministic | |||||
# # ------------ removals ------------ | |||||
# def propose_rem_div(triples: List, colours, rng: random.Random): | |||||
# cand = [trp for trp in triples if colours[trp[0]] != colours[trp[2]]] | |||||
# return ("rem", rng.choice(cand)) if cand else None | |||||
# def propose_rem_det(triples: List, colours, rng: random.Random): | |||||
# sig_rel = defaultdict(set) | |||||
# for h, r, t in triples: | |||||
# sig_rel[(colours[h], colours[t])].add(r) | |||||
# cand = [trp for trp in triples if len(sig_rel[(colours[trp[0]], colours[trp[2]])]) == 1] | |||||
# return ("rem", rng.choice(cand)) if cand else None | |||||
# def make_edit_proposal( | |||||
# triples_snapshot: List, | |||||
# colours_snapshot, | |||||
# c: float, | |||||
# lo: float, | |||||
# hi: float, | |||||
# n_ent: int, | |||||
# n_rel: int, | |||||
# depth: int, # still here for future use / signatures | |||||
# seed: int, | |||||
# ): | |||||
# """Return exactly one ('add' | 'rem', triple) proposal or None.""" | |||||
# rng = random.Random(seed) | |||||
# # -- decide which kind of edit we want, *given* the current c ------------ | |||||
# if c < lo: # ↑ WL‑CREC (delete deterministic ∨ add diversifying) | |||||
# chooser = (propose_rem_det, propose_add_div) | |||||
# else: # ↓ WL‑CREC (delete diversifying ∨ add deterministic) | |||||
# chooser = (propose_rem_div, propose_add_det) | |||||
# op = rng.choice(chooser) | |||||
# return ( | |||||
# op(triples_snapshot, colours_snapshot, n_ent, n_rel, rng) | |||||
# if op.__name__.startswith("propose_add") | |||||
# else op(triples_snapshot, colours_snapshot, rng) | |||||
# ) | |||||
# def tune_crec_parallel_edits( | |||||
# triples: List, | |||||
# n_ent: int, | |||||
# n_rel: int, | |||||
# depth: int, | |||||
# lo: float, | |||||
# hi: float, | |||||
# *, | |||||
# max_iters: int = 80_000, | |||||
# edits_per_eval: int = 1000, # == old “if it % 1000 == 0” | |||||
# batch_size: int = 256, # how many proposals we farm out at once | |||||
# max_workers: int = 4, | |||||
# seed: int = 42, | |||||
# ) -> List: | |||||
# rng_global = random.Random(seed) | |||||
# triples = set(triples) # deduplicate if needed | |||||
# with ThreadPoolExecutor(max_workers=max_workers) as pool: | |||||
# proposal_seed = seed * 997 # deterministic but different stream | |||||
# # edit_counter = 0 | |||||
# # while edit_counter < max_iters: | |||||
# # # ----------------- expensive part (single‑thread) ---------------- | |||||
# # dataset = KGDataset( | |||||
# # triples_factory=None, # Not used in this context | |||||
# # triples=torch.tensor(triples, dtype=torch.long), | |||||
# # num_entities=n_ent, | |||||
# # num_relations=n_rel, | |||||
# # ) | |||||
# # tmp = WLCREC(dataset) | |||||
# # colours = tmp.wl_colours(depth) | |||||
# # *_, c, _ = tmp.compute(H=5) | |||||
# # # ----------------------------------------------------------------- | |||||
# # if lo <= c <= hi: | |||||
# # logger.info( | |||||
# # "WL‑CREC %.4f reached after %d edits |T|=%d", c, edit_counter, len(triples) | |||||
# # ) | |||||
# # return triples | |||||
# # ============ parallel block: just make `edits_per_eval` proposals | |||||
# needed = min(edits_per_eval, max_iters - edit_counter) | |||||
# proposals = [] | |||||
# while len(proposals) < needed: | |||||
# # launch a batch of workers | |||||
# futs = [ | |||||
# pool.submit( | |||||
# make_edit_proposal, | |||||
# triples, | |||||
# colours, | |||||
# c, | |||||
# lo, | |||||
# hi, | |||||
# n_ent, | |||||
# n_rel, | |||||
# depth, | |||||
# proposal_seed + i, | |||||
# ) | |||||
# for i in range(batch_size) | |||||
# ] | |||||
# for f in as_completed(futs): | |||||
# prop = f.result() | |||||
# if prop is not None: | |||||
# proposals.append(prop) | |||||
# if len(proposals) == needed: | |||||
# break | |||||
# proposal_seed += batch_size # move RNG window forward | |||||
# # -------------- apply the gathered proposals *sequentially* ------- | |||||
# for kind, trp in proposals: | |||||
# if kind == "add": | |||||
# triples.append(trp) | |||||
# else: # "rem" | |||||
# try: | |||||
# triples.remove(trp) | |||||
# except ValueError: | |||||
# pass # already gone – benign collision | |||||
# # ----------------------------------------------------------------- | |||||
# edit_counter += needed | |||||
# if edit_counter % 1_000 == 0: | |||||
# logger.info("[iter %d] c=%.4f |T|=%d", edit_counter, c, len(triples)) | |||||
# raise RuntimeError("Exceeded max_iters without hitting target band.") | |||||
import os | |||||
import random | |||||
import threading | |||||
from collections import defaultdict | |||||
from typing import Dict, List, Tuple | |||||
# -------------------------------------------------------------------- | |||||
# Logging helper (replace with your own if you prefer) -------------- | |||||
# -------------------------------------------------------------------- | |||||
from tools import get_pretty_logger | |||||
logging = get_pretty_logger(__name__) | |||||
# -------------------------------------------------------------------- | |||||
# Original edge‑editing primitives (unchanged) ---------------------- | |||||
# -------------------------------------------------------------------- | |||||
def add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng): | |||||
for _ in range(1000): | |||||
h = rng.randrange(n_ent) | |||||
t = rng.randrange(n_ent) | |||||
if colours[h] != colours[t]: | |||||
r = rng.randrange(n_rel) | |||||
triples.append((h, r, t)) | |||||
out_adj[h].append((r, t)) | |||||
in_adj[t].append((r, h)) | |||||
return | |||||
def add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng): | |||||
σ = rng.choice(colours) | |||||
τ = rng.choice(colours) | |||||
h = colours.index(σ) | |||||
t = colours.index(τ) | |||||
triples.append((h, 0, t)) | |||||
out_adj[h].append((0, t)) | |||||
in_adj[t].append((0, h)) | |||||
def rem_div_edge(triples, out_adj, in_adj, colours, depth, rng): | |||||
cand = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]] | |||||
if cand: | |||||
idx = rng.choice(cand) | |||||
h, r, t = triples.pop(idx) | |||||
out_adj[h].remove((r, t)) | |||||
in_adj[t].remove((r, h)) | |||||
def rem_det_edge(triples, out_adj, in_adj, colours, depth, rng): | |||||
sig_rel = defaultdict(set) | |||||
for h, r, t in triples: | |||||
σ, τ = colours[h], colours[t] | |||||
sig_rel[(σ, τ)].add(r) | |||||
cand = [] | |||||
for i, (h, r, t) in enumerate(triples): | |||||
if len(sig_rel[(colours[h], colours[t])]) == 1: | |||||
cand.append(i) | |||||
if cand: | |||||
idx = rng.choice(cand) | |||||
h, r, t = triples.pop(idx) | |||||
out_adj[h].remove((r, t)) | |||||
in_adj[t].remove((r, h)) | |||||
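# --------------------------------------------------------------------
# Hedged usage sketch for the primitives above. The toy triples, colours
# and adjacency lists below are made-up assumptions for illustration,
# not data produced anywhere else in this repository.
# --------------------------------------------------------------------
def _edge_edit_demo() -> None:
    rng = random.Random(0)
    triples = [(0, 0, 1)]
    out_adj, in_adj = defaultdict(list), defaultdict(list)
    out_adj[0].append((0, 1))
    in_adj[1].append((0, 0))
    colours = [0, 0, 1]  # entity 2 carries a different WL colour
    # one diversifying add: the chosen endpoints must have different colours
    add_div_edge(triples, out_adj, in_adj, colours, depth=1, n_ent=3, n_rel=2, rng=rng)
    logging.info("triples after diversifying add: %s", triples)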
def _worker( | |||||
*, | |||||
worker_id: int, | |||||
rng: random.Random, | |||||
triples: List[Tuple[int, int, int]], | |||||
out_adj: Dict[int, List[Tuple[int, int]]], | |||||
in_adj: Dict[int, List[Tuple[int, int]]], | |||||
colours: List[int], | |||||
n_ent: int, | |||||
n_rel: int, | |||||
depth: int, | |||||
c: float, | |||||
lo: float, | |||||
hi: float, | |||||
max_iters: int, | |||||
state_lock: threading.Lock, | |||||
stop_event: threading.Event, | |||||
): | |||||
"""One thread: mutate the *shared* structures until success/stop.""" | |||||
for it in range(max_iters): | |||||
if stop_event.is_set(): | |||||
return # someone else finished ─ exit early | |||||
with state_lock: # protect the shared graph | |||||
# if lo <= c <= hi: | |||||
# logging.info( | |||||
# "[worker %d] converged after %d steps (CREC %.4f, |T|=%d)", | |||||
# worker_id, | |||||
# it, | |||||
# c, | |||||
# len(triples), | |||||
# ) | |||||
# stop_event.set() | |||||
# return | |||||
# Choose and apply one edit ----------------------------- | |||||
if c < lo: # need ↑ CREC | |||||
if rng.random() < 0.5: | |||||
rem_det_edge(triples, out_adj, in_adj, colours, depth, rng) | |||||
else: | |||||
add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng) | |||||
elif rng.random() < 0.5: | |||||
rem_div_edge(triples, out_adj, in_adj, colours, depth, rng) | |||||
else: | |||||
add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng) | |||||
logging.warning("[worker %d] reached max_iters", worker_id) | |||||
stop_event.set() | |||||
return | |||||
# -------------------------------------------------------------------- | |||||
# Public API -------------------------------------------------------- | |||||
# -------------------------------------------------------------------- | |||||
def fast_tune_crec( | |||||
triples: List, | |||||
colours: List, | |||||
n_ent: int, | |||||
n_rel: int, | |||||
depth: int, | |||||
c: float, | |||||
lo: float, | |||||
hi: float, | |||||
max_iters: int = 1000, | |||||
max_workers: int | None = None, | |||||
seeds: List[int] | None = None, | |||||
) -> List[Tuple[int, int, int]]: | |||||
"""Tune WL‑CREC with *shared* triples using multiple threads. | |||||
Returns the **same list instance** that was passed in – already | |||||
modified in place by the winning thread. | |||||
""" | |||||
if max_workers is None: | |||||
# max_workers = os.cpu_count() or 4 | |||||
max_workers = 4 | |||||
if seeds is None: | |||||
seeds = [42 + i for i in range(max_workers)] | |||||
assert len(seeds) >= max_workers, "Need at least one seed per worker" | |||||
# Prepare adjacency once (shared) -------------------------------- | |||||
out_adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list) | |||||
in_adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list) | |||||
for h, r, t in triples: | |||||
out_adj[h].append((r, t)) | |||||
in_adj[t].append((r, h)) | |||||
state_lock = threading.Lock() | |||||
stop_event = threading.Event() | |||||
logging.info( | |||||
"Launching %d threads on shared triples (target %.3f–%.3f)", | |||||
max_workers, | |||||
lo, | |||||
hi, | |||||
) | |||||
threads: List[threading.Thread] = [] | |||||
for wid in range(max_workers): | |||||
t = threading.Thread( | |||||
name=f"tune‑crec‑worker‑{wid}", | |||||
target=_worker, | |||||
kwargs=dict( | |||||
worker_id=wid, | |||||
rng=random.Random(seeds[wid]), | |||||
triples=triples, | |||||
out_adj=out_adj, | |||||
in_adj=in_adj, | |||||
colours=colours, | |||||
n_ent=n_ent, | |||||
n_rel=n_rel, | |||||
depth=depth, | |||||
c=c, | |||||
lo=lo, | |||||
hi=hi, | |||||
max_iters=max_iters, | |||||
state_lock=state_lock, | |||||
stop_event=stop_event, | |||||
), | |||||
daemon=False, | |||||
) | |||||
threads.append(t) | |||||
t.start() | |||||
for t in threads: | |||||
t.join() | |||||
if not stop_event.is_set(): | |||||
raise RuntimeError("No thread converged – try increasing max_iters or widen band") | |||||
return triples |
from __future__ import annotations | |||||
import logging | |||||
import random | |||||
from collections import defaultdict, deque | |||||
from math import isclose | |||||
from typing import List | |||||
import torch | |||||
from data.kg_dataset import KGDataset | |||||
from metrics.wlcrec import WLCREC | |||||
from tools import get_pretty_logger | |||||
logging = get_pretty_logger(__name__) | |||||
# -------------------------------------------------------------------- | |||||
# Radius-k patch sampler | |||||
# -------------------------------------------------------------------- | |||||
def radius_patch_sample( | |||||
all_triples: List, | |||||
ratio: float, | |||||
radius: int, | |||||
rng: random.Random, | |||||
) -> List: | |||||
""" | |||||
Take ~ratio fraction of triples by picking random seeds and | |||||
keeping *all* triples within ≤ radius hops (undirected) of them. | |||||
""" | |||||
n_total = len(all_triples) | |||||
target = int(ratio * n_total) | |||||
# Build entity → triple-indices adjacency for BFS | |||||
ent2trip = defaultdict(list) | |||||
for idx, (h, _, t) in enumerate(all_triples): | |||||
ent2trip[h].append(idx) | |||||
ent2trip[t].append(idx) | |||||
chosen_idx: set[int] = set() | |||||
visited_ent = set() | |||||
while len(chosen_idx) < target: | |||||
seed_idx = rng.randrange(n_total) | |||||
if seed_idx in chosen_idx: | |||||
continue | |||||
# BFS over entities | |||||
queue = deque() | |||||
for e in all_triples[seed_idx][::2]: # subject & object | |||||
if e not in visited_ent: | |||||
queue.append((e, 0)) | |||||
visited_ent.add(e) | |||||
while queue: | |||||
ent, dist = queue.popleft() | |||||
for tidx in ent2trip[ent]: | |||||
if tidx not in chosen_idx: | |||||
chosen_idx.add(tidx) | |||||
if dist < radius: | |||||
for tidx in ent2trip[ent]: | |||||
h, _, t = all_triples[tidx] | |||||
for nb in (h, t): | |||||
if nb not in visited_ent: | |||||
visited_ent.add(nb) | |||||
queue.append((nb, dist + 1)) | |||||
# guard against infinite loop on very small ratio | |||||
if isclose(ratio, 0.0) and chosen_idx: | |||||
break | |||||
return [all_triples[i] for i in chosen_idx] | |||||
# -------------------------------------------------------------------- | |||||
# Main driver: repeated sampling until WL-CREC band hit | |||||
# -------------------------------------------------------------------- | |||||
def find_sample( | |||||
all_triples: List, | |||||
n_ent: int, | |||||
n_rel: int, | |||||
depth: int, | |||||
ratio: float, | |||||
radius: int, | |||||
lower: float, | |||||
upper: float, | |||||
max_tries: int, | |||||
seed: int, | |||||
) -> List: | |||||
for trial in range(1, max_tries + 1): | |||||
rng = random.Random(seed + trial) # fresh randomness each try | |||||
sample = radius_patch_sample(all_triples, ratio, radius, rng) | |||||
sample_ds = KGDataset( | |||||
triples_factory=None, | |||||
triples=torch.tensor(sample, dtype=torch.long), | |||||
num_entities=n_ent, | |||||
num_relations=n_rel, | |||||
split="train", | |||||
) | |||||
wl_crec = WLCREC(sample_ds) | |||||
crec_val = wl_crec.compute(depth)[-2] | |||||
logging.info("Try %d |T'|=%d WL-CREC=%.4f", trial, len(sample), crec_val) | |||||
if lower <= crec_val <= upper: | |||||
logging.info("Success after %d tries.", trial) | |||||
return sample | |||||
raise RuntimeError(f"No sample reached WL-CREC band in {max_tries} tries.") |
from __future__ import annotations | |||||
import logging | |||||
import random | |||||
from math import log | |||||
from typing import Dict, List, Tuple | |||||
import torch | |||||
from torch import Tensor | |||||
from tools import get_pretty_logger | |||||
logging = get_pretty_logger(__name__) | |||||
Entity = int | |||||
Relation = int | |||||
Triple = Tuple[Entity, Relation, Entity] | |||||
Color = int | |||||
# --------------------------------------------------------------------- | |||||
# Weisfeiler–Lehman colouring (GPU friendly) | |||||
# --------------------------------------------------------------------- | |||||
def wl_colours_gpu(triples: Tensor, n_ent: int, depth: int, device: torch.device) -> List[Tensor]: | |||||
""" | |||||
GPU implementation of classic Weisfeiler‑Lehman refinement. | |||||
1. Each node starts with colour 0. | |||||
2. At iteration h we hash the multiset of | |||||
⬇︎ (r, colour(t)) and ⬆︎ (r, colour(h)) | |||||
signatures into new colour ids. | |||||
3. Hashing is done with a fast 64‑bit mix; collisions are unlikely and | |||||
immaterial for CREC (only *relative* colours matter). | |||||
""" | |||||
h, r, t = triples.unbind(dim=1) # (|T|,) | |||||
colours = [torch.zeros(n_ent, dtype=torch.long, device=device)] # C⁰ = 0 | |||||
# pre‑compute 64‑bit mix constants once | |||||
mix_r = ( | |||||
torch.arange(triples[:, 1].max() + 1, device=device, dtype=torch.long) | |||||
* 1_146_189_683_093_321_123 | |||||
) | |||||
mix_c = 8_636_673_225_737_527_201 | |||||
for _ in range(depth): | |||||
prev = colours[-1] # (|V|,) | |||||
# signatures for outgoing edges | |||||
sig_out = mix_r[r] ^ (prev[t] * mix_c) | |||||
# signatures for incoming edges | |||||
sig_in = mix_r[r] ^ (prev[h] * mix_c) ^ 0x9E3779B97F4A7C15 | |||||
# bucket signatures back to the source/target nodes | |||||
col_out = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, h, sig_out) | |||||
col_in = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, t, sig_in) | |||||
# combine ↓ and ↑ multiset hashes with current colour | |||||
raw = (prev * 3_205_813_371) ^ col_out ^ (col_in << 1) | |||||
# re‑map to dense consecutive ids with torch.unique | |||||
uniq, new = torch.unique(raw, sorted=True, return_inverse=True) | |||||
colours.append(new) | |||||
return colours # length = depth+1, each (|V|,) long | |||||
# --------------------------------------------------------------------- | |||||
# WL‑CREC metric | |||||
# --------------------------------------------------------------------- | |||||
def wl_crec_gpu( | |||||
triples: Tensor, colours: List[Tensor], n_ent: int, n_rel: int, depth: int, device: torch.device | |||||
) -> float: | |||||
log_n = log(max(n_ent, 2)) | |||||
# ------- pattern‑diversity C ------------------------------------- | |||||
C = 0.0 | |||||
for c in colours: # (|V|,) | |||||
# histogram via bincount on GPU | |||||
hist = torch.bincount(c).float() | |||||
p = hist / n_ent | |||||
C += -(p * torch.log(p.clamp_min(1e-30))).sum().item() | |||||
C /= (depth + 1) * log_n | |||||
# ------- residual entropy H_c ------------------------------------ | |||||
h, r, t = triples.unbind(dim=1) | |||||
sigma, tau = colours[depth][h], colours[depth][t] # (|T|,) | |||||
# encode (sigma,tau) pairs into a single 64‑bit key | |||||
sig_keys = (sigma.to(torch.int64) << 32) + tau.to(torch.int64) | |||||
# total count per signature | |||||
sig_unique, sig_inv, sig_counts = torch.unique( | |||||
sig_keys, return_inverse=True, return_counts=True, sorted=False | |||||
) | |||||
    # build 2-D contingency table counts[(sigma,tau), r]
    m = triples.size(0)
    keys_2d = sig_inv * n_rel + r
    # accumulate counts on the flattened (|sigma|, n_rel) table and reshape;
    # keys_2d already encodes the flat index signature * n_rel + relation, so
    # index_add_ is the right primitive (scatter_add_ along dim 0 would read the
    # keys as row indices and run out of bounds)
    rc_flat = torch.zeros(len(sig_unique) * n_rel, device=device, dtype=torch.long)
    rc_flat.index_add_(0, keys_2d, torch.ones(m, device=device, dtype=torch.long))
    rc_dense = rc_flat.view(len(sig_unique), n_rel)
# conditional entropy per (sigma,tau) | |||||
p_r = rc_dense.float() / sig_counts.unsqueeze(1).float() | |||||
inner = -(p_r * torch.log(p_r.clamp_min(1e-30))).sum(dim=1) # (|sigma|,) | |||||
Hc = (sig_counts.float() / m * inner).sum().item() / log(max(n_rel, 2)) | |||||
return C * Hc | |||||
# --------------------------------------------------------------------- | |||||
# delta‑entropy helpers (still CPU side for clarity; cost is negligible) | |||||
# --------------------------------------------------------------------- | |||||
# The deterministic delta‑formulas depend on small dictionaries and are evaluated | |||||
# on ≤256 candidate edges each iteration – copy them from the original script. | |||||
inner_entropy = ... # unchanged | |||||
delta_h_cond_remove = ... # unchanged | |||||
delta_h_cond_add = ... # unchanged | |||||
# --------------------------------------------------------------------- | |||||
# Greedy delta‑search driver | |||||
# --------------------------------------------------------------------- | |||||
def greedy_delta_tune_gpu( | |||||
triples: List[Triple], | |||||
n_ent: int, | |||||
n_rel: int, | |||||
depth: int, | |||||
lower: float, | |||||
upper: float, | |||||
device: torch.device, | |||||
max_iters: int = 40_000, | |||||
sample_size: int = 256, | |||||
seed: int = 0, | |||||
) -> List[Triple]: | |||||
# book‑keeping in Python lists (cheap); heavy math in torch on GPU | |||||
rng = random.Random(seed) | |||||
# cache torch version of triples for fast metric eval | |||||
def triples_to_tensor(ts: List[Triple]) -> Tensor: | |||||
return torch.tensor(ts, dtype=torch.long, device=device, requires_grad=False) | |||||
for it in range(max_iters): | |||||
tri_tensor = triples_to_tensor(triples) | |||||
colours = wl_colours_gpu(tri_tensor, n_ent, depth, device) | |||||
crec = wl_crec_gpu(tri_tensor, colours, n_ent, n_rel, depth, device) | |||||
if lower <= crec <= upper: | |||||
logging.info("WL‑CREC %.4f reached after %d edits (|T|=%d)", crec, it, len(triples)) | |||||
return triples | |||||
# ---------------------------------------------------------------- | |||||
# build signature statistics (CPU, cheap: O(|T|)) | |||||
# ---------------------------------------------------------------- | |||||
depth_col = colours[depth].cpu() | |||||
sig_cnt: Dict[Tuple[int, int], int] = {} | |||||
rel_cnt: Dict[Tuple[int, int, int], int] = {} | |||||
for h_idx, (h, r, t) in enumerate(triples): | |||||
sigma, tau = int(depth_col[h]), int(depth_col[t]) | |||||
sig_cnt[(sigma, tau)] = sig_cnt.get((sigma, tau), 0) + 1 | |||||
rel_cnt[(sigma, tau, r)] = rel_cnt.get((sigma, tau, r), 0) + 1 | |||||
det_edges, div_edges = [], [] | |||||
for idx, (h, r, t) in enumerate(triples): | |||||
sigma, tau = int(depth_col[h]), int(depth_col[t]) | |||||
if sum(rel_cnt.get((sigma, tau, rr), 0) > 0 for rr in range(n_rel)) == 1: | |||||
det_edges.append(idx) | |||||
if sigma != tau: | |||||
div_edges.append(idx) | |||||
# ---------------------------------------------------------------- | |||||
# candidate generation + best edit selection | |||||
# ---------------------------------------------------------------- | |||||
total = len(triples) | |||||
target_high = crec < lower # need to raise WL‑CREC? | |||||
candidates = [] | |||||
if target_high and det_edges: | |||||
rng.shuffle(det_edges) | |||||
for idx in det_edges[:sample_size]: | |||||
h, r, t = triples[idx] | |||||
sig = (int(depth_col[h]), int(depth_col[t])) | |||||
delta = delta_h_cond_remove(sig, r, sig_cnt, rel_cnt, total, crec, n_rel) | |||||
if delta > 0: | |||||
candidates.append(("remove", idx, delta)) | |||||
elif not target_high and det_edges: | |||||
rng.shuffle(det_edges) | |||||
for idx in det_edges[:sample_size]: | |||||
h, r, t = triples[idx] | |||||
sig = (int(depth_col[h]), int(depth_col[t])) | |||||
delta = delta_h_cond_add(sig, r, sig_cnt, rel_cnt, total, crec, n_rel) | |||||
if delta < 0: | |||||
candidates.append(("add", (h, r, t), delta)) | |||||
# fall‑back heuristics | |||||
if not candidates: | |||||
if target_high and div_edges: | |||||
idx = rng.choice(div_edges) | |||||
candidates.append(("remove", idx, 1e-9)) | |||||
elif not target_high: | |||||
idx = rng.choice(det_edges) if det_edges else rng.randrange(total) | |||||
h, r, t = triples[idx] | |||||
candidates.append(("add", (h, r, t), -1e-9)) | |||||
        if not candidates:
            # nothing applicable this round (e.g. no deterministic or
            # diversifying edges left) - resample on the next iteration
            continue
        best = (
            max(candidates, key=lambda x: x[2])
            if target_high
            else min(candidates, key=lambda x: x[2])
        )
# apply edit | |||||
if best[0] == "remove": | |||||
triples.pop(best[1]) | |||||
else: | |||||
triples.append(best[1]) | |||||
if (it + 1) % 1_000 == 0: | |||||
logging.info("[iter %d] WL‑CREC %.4f |T|=%d", it + 1, crec, len(triples)) | |||||
raise RuntimeError("Max iterations exceeded without hitting WL‑CREC band.") |
import torch | |||||
from models.base_model import ERModel | |||||
def _rank(scores: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |||||
return (scores > scores[target]).sum() + 1 | |||||
def simple_ranking(model: ERModel, triples: torch.Tensor) -> torch.Tensor:
    """Rank the true tail among all candidate tails for every (h, r, t) triple."""
    h, r, t = triples[:, 0], triples[:, 1], triples[:, 2]
    E = model.num_entities
    device = triples.device
    ranks = []
    with torch.no_grad():
        for hi, ri, ti in zip(h, r, t):
            tails = torch.arange(E, device=device)
            # build the (E, 3) candidate batch instead of overwriting the input triples
            candidates = torch.stack([hi.repeat(E), ri.repeat(E), tails], dim=1)
            scores = model(candidates, mode=None)
            ranks.append(_rank(scores, ti))
    return torch.tensor(ranks, device=device, dtype=torch.float32)
def mrr(r: torch.Tensor) -> torch.Tensor: | |||||
return (1 / r).mean() | |||||
def hits_at_k(r: torch.Tensor, k: int = 10) -> torch.Tensor: | |||||
return (r <= k).float().mean() |
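# Hedged sanity check on hand-made ranks (illustrative only, not model output).
def _ranking_metrics_demo() -> None:
    ranks = torch.tensor([1.0, 3.0, 12.0])
    # MRR = mean(1/1, 1/3, 1/12) = 17/36 ≈ 0.4722 ; Hits@10 = 2/3 ≈ 0.6667
    print(f"MRR={mrr(ranks):.4f}  Hits@10={hits_at_k(ranks, k=10):.4f}")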
from __future__ import annotations | |||||
import math | |||||
from collections import Counter, defaultdict | |||||
from typing import TYPE_CHECKING, Dict, List, Tuple | |||||
import torch | |||||
from .base_metric import BaseMetric | |||||
if TYPE_CHECKING: | |||||
from data.kg_dataset import KGDataset | |||||
def _entropy_from_counter(counter: Counter[int], n: int) -> float: | |||||
"""Return Shannon entropy (nats) of a colour distribution of length *n*.""" | |||||
ent = 0.0 | |||||
for cnt in counter.values(): | |||||
if cnt: | |||||
p = cnt / n | |||||
ent -= p * math.log(p) | |||||
return ent | |||||
def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]: | |||||
"""Return (C_ratio, C_NWLEC) ∈ [0,1]² from per-layer entropies.""" | |||||
if n <= 1: | |||||
return 0.0, 0.0 | |||||
H_plus_1 = len(entropies) | |||||
log_n = math.log(n) | |||||
c_ratio = sum(ent / log_n for ent in entropies) / H_plus_1 | |||||
c_nwlec = sum((math.exp(ent) - 1) / (n - 1) for ent in entropies) / H_plus_1 | |||||
return c_ratio, c_nwlec | |||||
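# Hedged worked example (illustrative only): two layers over n = 4 entities,
# a uniform colouring (entropy log 4) and a constant one (entropy 0), giving
# C_ratio = (1 + 0) / 2 = 0.5 and C_NWLEC = ((4 - 1) / (4 - 1) + 0) / 2 = 0.5.
def _normalise_scores_demo() -> None:
    c_ratio, c_nwlec = _normalise_scores([math.log(4), 0.0], 4)
    print(c_ratio, c_nwlec)  # both 0.5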
def _relation_families( | |||||
triples: torch.Tensor, num_relations: int, *, thresh: float = 0.9 | |||||
) -> torch.Tensor: | |||||
"""Return a LongTensor mapping each relation id → family id. | |||||
Two relations *r, r_inv* are put into the same family when **at least | |||||
`thresh` fraction** of the edges labelled *r* have a **single** reverse edge | |||||
labelled *r_inv*. | |||||
""" | |||||
assert 0.0 < thresh <= 1.0 | |||||
# Build map (h,t) -> list of relations on that edge direction. | |||||
edge2rels: Dict[Tuple[int, int], List[int]] = defaultdict(list) | |||||
for h, r, t in triples.tolist(): | |||||
edge2rels[(h, t)].append(int(r)) | |||||
# Count reciprocal co‑occurrences (r, r_rev). | |||||
pair_counts: Dict[Tuple[int, int], int] = defaultdict(int) | |||||
rel_totals: List[int] = [0] * num_relations | |||||
for (h, t), rels in edge2rels.items(): | |||||
rev_rels = edge2rels.get((t, h)) | |||||
if not rev_rels: | |||||
continue | |||||
for r in rels: | |||||
rel_totals[r] += 1 | |||||
for r_rev in rev_rels: | |||||
pair_counts[(r, r_rev)] += 1 | |||||
# Proposed inverse mapping r -> r_inv. | |||||
proposed: List[int | None] = [None] * num_relations | |||||
for (r, r_rev), cnt in pair_counts.items(): | |||||
if cnt == 0: | |||||
continue | |||||
if cnt / rel_totals[r] >= thresh: | |||||
# majority of r's edges are reversed by r_rev | |||||
proposed[r] = r_rev | |||||
# Build family ids (transitive closure is overkill; data are simple). | |||||
family = list(range(num_relations)) | |||||
for r, rinv in enumerate(proposed): | |||||
if rinv is not None: | |||||
root = min(family[r], family[rinv]) | |||||
family[r] = family[rinv] = root | |||||
return torch.tensor(family, dtype=torch.long) | |||||
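# Hedged usage sketch (illustrative only): relations 0 and 1 always reverse
# each other in this toy tensor, so they collapse into a single family.
def _relation_families_demo() -> None:
    toy = torch.tensor([[0, 0, 1], [1, 1, 0], [2, 0, 3], [3, 1, 2]], dtype=torch.long)
    print(_relation_families(toy, num_relations=2, thresh=0.9).tolist())  # [0, 0]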
def _conditional_relation_entropy( | |||||
colours: List[int], | |||||
triples: torch.Tensor, | |||||
family_map: torch.Tensor, | |||||
) -> float: | |||||
counts: Dict[Tuple[int, int], Dict[int, int]] = defaultdict(lambda: defaultdict(int)) | |||||
m = int(triples.size(0)) | |||||
fam = family_map # alias for speed | |||||
for h, r, t in triples.tolist(): | |||||
# key = (colours[h], colours[t]) # ordered key (direction‑aware) | |||||
key = tuple(sorted((colours[h], colours[t]))) # unordered key (direction‑agnostic) | |||||
counts[key][int(fam[r])] += 1 # ▼ use family id | |||||
h_cond = 0.0 | |||||
for rel_counts in counts.values(): | |||||
total_s = sum(rel_counts.values()) | |||||
P_s = total_s / m | |||||
inv_total = 1.0 / total_s | |||||
for cnt in rel_counts.values(): | |||||
p_rs = cnt * inv_total | |||||
h_cond -= P_s * p_rs * math.log(p_rs) | |||||
return h_cond | |||||
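# Hedged sanity check (illustrative only): with a single relation family the
# conditional entropy H(R | S) is zero regardless of the colouring.
def _conditional_relation_entropy_demo() -> None:
    colours = [0, 0, 1, 1]
    toy = torch.tensor([[0, 0, 2], [1, 0, 3]], dtype=torch.long)
    family_map = torch.tensor([0], dtype=torch.long)  # one relation -> one family
    print(_conditional_relation_entropy(colours, toy, family_map))  # 0.0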
class WLCREC(BaseMetric): | |||||
"""Class to compute the Weisfeiler-Lehman Entropy Complexity (WLEC) for a knowledge graph.""" | |||||
def __init__(self, dataset: KGDataset) -> None: | |||||
super().__init__(dataset) | |||||
def compute( | |||||
self, | |||||
H: int = 3, | |||||
cond_h: int = 1, | |||||
inv_thresh: float = 0.9, | |||||
return_full: bool = False, | |||||
progress: bool = True, | |||||
    ) -> Tuple[float, ...]:
        """Compute WL-entropy scores and the composite difficulty metric.
        Parameters
        ----------
        H
            WL refinement depth (*≥0*). Entropy is recorded for **H+1** layers.
        cond_h
            Index of the WL layer whose colours condition the residual relation entropy.
        inv_thresh
            Reciprocity threshold for merging a relation and its inverse into one family.
        return_full
            Also return the list of per-layer entropies.
        progress
            Print textual progress bar.
Returns | |||||
------- | |||||
avg_entropy, C_ratio, C_NWLEC, H_cond, D_ratio, D_NWLEC | |||||
*Six* scalars (floats). If ``return_full=True`` an extra list of | |||||
layer entropies is appended. | |||||
""" | |||||
# compute relation family map once | |||||
family_map = _relation_families( | |||||
self.dataset.triples, self.dataset.num_relations, thresh=inv_thresh | |||||
) | |||||
# ─── WL iterations ──────────────────────────────────────────────────── | |||||
n = self.dataset.num_entities | |||||
colour: List[int] = [1] * n # colour 1 for everyone | |||||
entropies: List[float] = [] | |||||
colours_per_h: List[List[int]] = [] | |||||
def _print_prog(i: int) -> None: | |||||
if progress: | |||||
print(f"\rWL iteration {i}/{H}", end="", flush=True) | |||||
for h in range(H + 1): | |||||
_print_prog(h) | |||||
colours_per_h.append(colour.copy()) | |||||
# entropy of current colouring | |||||
freq = Counter(colour) | |||||
entropies.append(_entropy_from_counter(freq, n)) | |||||
if h == H: | |||||
break | |||||
# refine colours | |||||
bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {} | |||||
next_colour: List[int] = [0] * n | |||||
for v in range(n): | |||||
T: List[Tuple[int, int, int]] = [] | |||||
for r, t in self.out_adj[v]: | |||||
T.append((r, 0, colour[t])) # outgoing | |||||
for r, h_ in self.in_adj[v]: | |||||
T.append((r, 1, colour[h_])) # incoming | |||||
T.sort() | |||||
key = (colour[v], tuple(T)) | |||||
if key not in bucket: | |||||
bucket[key] = len(bucket) + 1 | |||||
next_colour[v] = bucket[key] | |||||
colour = next_colour | |||||
_print_prog(H) | |||||
if progress: | |||||
print() | |||||
# ─── Normalised diversity ──────────────────────────────────────────── | |||||
avg_entropy = sum(entropies) / len(entropies) | |||||
c_ratio, c_nwlec = _normalise_scores(entropies, n) | |||||
# ─── Residual relation entropy ─────────────────────────────────────── | |||||
h_cond = _conditional_relation_entropy( | |||||
colours_per_h[cond_h], self.dataset.triples, family_map | |||||
) | |||||
# ─── Composite difficulty - Eq. (1) ────────────────────────────────── | |||||
d_ratio = c_ratio * h_cond | |||||
d_nwlec = c_nwlec * h_cond | |||||
result = (avg_entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec) | |||||
if return_full: | |||||
result += (entropies,) | |||||
return result | |||||
# -------------------------------------------------------------------- | |||||
# Weisfeiler–Lehman colours (unchanged) | |||||
# -------------------------------------------------------------------- | |||||
def wl_colours(self, depth: int) -> List[List[int]]: | |||||
triples = self.dataset.triples.tolist() | |||||
out_adj, in_adj = defaultdict(list), defaultdict(list) | |||||
for h, r, t in triples: | |||||
out_adj[h].append((r, t)) | |||||
in_adj[t].append((r, h)) | |||||
n_ent = self.dataset.num_entities | |||||
colours = [[0] * n_ent] # round-0 | |||||
for h in range(1, depth + 1): | |||||
prev, nxt, sig2c = colours[-1], [0] * n_ent, {} | |||||
fresh = 0 | |||||
for v in range(n_ent): | |||||
neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [ | |||||
("↑", r, prev[u]) for r, u in in_adj.get(v, []) | |||||
] | |||||
neigh.sort() | |||||
sig = (prev[v], tuple(neigh)) | |||||
if sig not in sig2c: | |||||
sig2c[sig] = fresh | |||||
fresh += 1 | |||||
nxt[v] = sig2c[sig] | |||||
colours.append(nxt) | |||||
return colours |
from __future__ import annotations | |||||
import math | |||||
from collections import Counter | |||||
from typing import TYPE_CHECKING, Dict, List, Tuple | |||||
from .base_metric import BaseMetric | |||||
if TYPE_CHECKING: | |||||
from data.kg_dataset import KGDataset | |||||
def _entropy_from_counter(counter: Counter[int], n: int) -> float: | |||||
"""Shannon entropy ``H = -Σ pᵢ log pᵢ`` in *nats* for the given colour counts.""" | |||||
ent = 0.0 | |||||
for count in counter.values(): | |||||
if count: # avoid log(0) | |||||
p = count / n | |||||
ent -= p * math.log(p) | |||||
return ent | |||||
def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]: | |||||
"""Return (C_ratio, C_NWLEC) given *per-iteration* entropies.""" | |||||
if n <= 1: | |||||
# Degenerate graph. | |||||
return 0.0, 0.0 | |||||
H_plus_1 = len(entropies) | |||||
# Eq. (1): simple ratio normalisation. | |||||
c_ratio = sum(ent / math.log(n) for ent in entropies) / H_plus_1 | |||||
# Eq. (2): effective-colour NWLEC. | |||||
k_terms = [(math.exp(ent) - 1) / (n - 1) for ent in entropies] | |||||
c_nwlec = sum(k_terms) / H_plus_1 | |||||
return c_ratio, c_nwlec | |||||
class WLEC(BaseMetric): | |||||
"""Class to compute the Weisfeiler-Lehman Entropy Complexity (WLEC) for a knowledge graph.""" | |||||
def __init__(self, dataset: KGDataset) -> None: | |||||
super().__init__(dataset) | |||||
def compute( | |||||
self, | |||||
H: int = 3, | |||||
*, | |||||
return_full: bool = False, | |||||
progress: bool = True, | |||||
    ) -> Tuple[float, ...]:
        """Compute the *average* Weisfeiler-Lehman Entropy Complexity (WLEC).
        Parameters
        ----------
        H:
            Number of refinement iterations **after** the initial colouring (so the
            colours are updated **H** times and the entropy is measured **H+1** times).
        return_full:
            If *True*, additionally return the list of entropies for each depth.
        progress:
            Whether to print a mini progress bar (no external dependency).
        Returns
        -------
        (average_entropy, c_ratio, c_nwlec), or
        (average_entropy, entropies, c_ratio, c_nwlec) when ``return_full=True``.
        """
n = int(self.num_entities) | |||||
if n == 0: | |||||
msg = "Dataset appears to contain zero entities." | |||||
raise ValueError(msg) | |||||
# Colour assignment - we keep it as a *list* of ints for cheap hashing. | |||||
# Start with colour 1 for every entity (index 0 unused). | |||||
colour: List[int] = [1] * n | |||||
entropies: List[float] = [] | |||||
# Optional poor-man's progress bar. | |||||
def _print_prog(i: int) -> None: | |||||
if progress: | |||||
print(f"\rIteration {i}/{H}", end="", flush=True) | |||||
for h in range(H + 1): | |||||
_print_prog(h) | |||||
# ------------------------------------------------------------------ | |||||
# Step 1: measure entropy of current colouring ---------------------- | |||||
# ------------------------------------------------------------------ | |||||
freq = Counter(colour) | |||||
ent = _entropy_from_counter(freq, n) | |||||
entropies.append(ent) | |||||
if h == H: | |||||
break # We have reached the requested depth - stop here. | |||||
# ------------------------------------------------------------------ | |||||
# Step 2: create refined colours ----------------------------------- | |||||
# ------------------------------------------------------------------ | |||||
bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {} | |||||
next_colour: List[int] = [0] * n | |||||
for v in range(n): | |||||
# Build the multiset T containing outgoing and incoming features. | |||||
T: List[Tuple[int, int, int]] = [] | |||||
for r, t in self.out_adj[v]: | |||||
T.append((r, 0, colour[t])) # 0 = outgoing (\u2193) | |||||
for r, h_ in self.in_adj[v]: | |||||
T.append((r, 1, colour[h_])) # 1 = incoming (\u2191) | |||||
T.sort() # canonical ordering | |||||
key = (colour[v], tuple(T)) | |||||
if key not in bucket: | |||||
bucket[key] = len(bucket) + 1 # new colour ID (start at 1) | |||||
next_colour[v] = bucket[key] | |||||
colour = next_colour # move to next iteration | |||||
_print_prog(H) | |||||
if progress: | |||||
print() # newline after progress bar | |||||
average_entropy = sum(entropies) / len(entropies) | |||||
c_ratio, c_nwlec = _normalise_scores(entropies, n) | |||||
if return_full: | |||||
return average_entropy, entropies, c_ratio, c_nwlec | |||||
return average_entropy, c_ratio, c_nwlec |
from __future__ import annotations | |||||
from abc import ABC, abstractmethod | |||||
from typing import TYPE_CHECKING | |||||
import torch | |||||
from torch import nn | |||||
from torch.nn.functional import log_softmax, softmax | |||||
if TYPE_CHECKING: | |||||
from pykeen.typing import FloatTensor, InductiveMode, LongTensor, MappedTriples, Target | |||||
class ERModel(nn.Module, ABC): | |||||
"""Base class for knowledge graph models.""" | |||||
def __init__(self, num_entities: int, num_relations: int, dim: int) -> None: | |||||
super().__init__() | |||||
self.num_entities = num_entities | |||||
self.num_relations = num_relations | |||||
self.dim = dim | |||||
self.inner_model = None | |||||
@property | |||||
def device(self) -> torch.device: | |||||
"""Return the device of the model.""" | |||||
return next(self.inner_model.parameters()).device | |||||
@abstractmethod | |||||
def reset_parameters(self, *args, **kwargs) -> None: | |||||
"""Reset the parameters of the model.""" | |||||
def predict( | |||||
self, | |||||
hrt_batch: MappedTriples, | |||||
target: Target, | |||||
full_batch: bool = True, | |||||
ids: LongTensor | None = None, | |||||
**kwargs, | |||||
) -> FloatTensor: | |||||
if not self.inner_model: | |||||
msg = ( | |||||
"Inner model is not set. Please initialize the inner model before calling predict." | |||||
) | |||||
raise ValueError( | |||||
msg, | |||||
) | |||||
return self.inner_model.predict( | |||||
hrt_batch=hrt_batch, | |||||
target=target, | |||||
full_batch=full_batch, | |||||
ids=ids, | |||||
**kwargs, | |||||
) | |||||
def forward( | |||||
self, | |||||
triples: MappedTriples, | |||||
slice_size: int | None = None, | |||||
slice_dim: int = 0, | |||||
*, | |||||
mode: InductiveMode | None, | |||||
) -> FloatTensor: | |||||
h_indices = triples[:, 0] | |||||
r_indices = triples[:, 1] | |||||
t_indices = triples[:, 2] | |||||
if not self.inner_model: | |||||
msg = ( | |||||
"Inner model is not set. Please initialize the inner model before calling forward." | |||||
) | |||||
raise ValueError( | |||||
msg, | |||||
) | |||||
return self.inner_model.forward( | |||||
h_indices=h_indices, | |||||
r_indices=r_indices, | |||||
t_indices=t_indices, | |||||
slice_size=slice_size, | |||||
slice_dim=slice_dim, | |||||
mode=mode, | |||||
) | |||||
@torch.inference_mode() | |||||
def tail_distribution( # (0) TAIL-given-(head, relation) → pθ(t | h,r) | |||||
self, | |||||
hr_batch: LongTensor, # (B, 2) : (h,r) | |||||
*, | |||||
slice_size: int | None = None, | |||||
mode: InductiveMode | None = None, | |||||
log: bool = False, | |||||
) -> FloatTensor: # (B, |E|) | |||||
scores = self.inner_model.score_t(hr_batch=hr_batch, slice_size=slice_size, mode=mode) | |||||
return log_softmax(scores, -1) if log else softmax(scores, -1) | |||||
# ----------------------------------------------------------------------- # | |||||
# (1) HEAD-given-(relation, tail) → pθ(h | r,t) # | |||||
# ----------------------------------------------------------------------- # | |||||
@torch.inference_mode() | |||||
def head_distribution( | |||||
self, | |||||
rt_batch: LongTensor, # (B, 2) : (r,t) | |||||
*, | |||||
slice_size: int | None = None, | |||||
mode: InductiveMode | None = None, | |||||
log: bool = False, | |||||
) -> FloatTensor: # (B, |E|) | |||||
scores = self.inner_model.score_h(rt_batch=rt_batch, slice_size=slice_size, mode=mode) | |||||
return log_softmax(scores, -1) if log else softmax(scores, -1) | |||||
# ----------------------------------------------------------------------- # | |||||
# (2) RELATION-given-(head, tail) → pθ(r | h,t) # | |||||
# ----------------------------------------------------------------------- # | |||||
@torch.inference_mode() | |||||
def relation_distribution( | |||||
self, | |||||
ht_batch: LongTensor, # (B, 2) : (h,t) | |||||
*, | |||||
slice_size: int | None = None, | |||||
mode: InductiveMode | None = None, | |||||
log: bool = False, | |||||
) -> FloatTensor: # (B, |R|) | |||||
scores = self.inner_model.score_r(ht_batch=ht_batch, slice_size=slice_size, mode=mode) | |||||
return log_softmax(scores, -1) if log else softmax(scores, -1) |
from __future__ import annotations | |||||
from typing import TYPE_CHECKING | |||||
import torch | |||||
from pykeen.models import TransE as PyKEENTransE | |||||
from torch import nn | |||||
from models.base_model import ERModel | |||||
if TYPE_CHECKING: | |||||
from pykeen.typing import MappedTriples | |||||
class TransE(ERModel): | |||||
""" | |||||
TransE model for knowledge graph embedding. | |||||
score = ||h + r - t|| | |||||
""" | |||||
def __init__( | |||||
self, | |||||
num_entities: int, | |||||
num_relations: int, | |||||
triples_factory: MappedTriples, | |||||
dim: int = 200, | |||||
sf_norm: int = 1, | |||||
p_norm: bool = True, | |||||
margin: float | None = None, | |||||
epsilon: float | None = None, | |||||
device: torch.device = torch.device("cpu"), | |||||
) -> None: | |||||
super().__init__(num_entities, num_relations, dim) | |||||
self.dim = dim | |||||
self.sf_norm = sf_norm | |||||
self.p_norm = p_norm | |||||
self.margin = margin | |||||
self.epsilon = epsilon | |||||
self.inner_model = PyKEENTransE( | |||||
triples_factory=triples_factory, | |||||
embedding_dim=dim, | |||||
scoring_fct_norm=sf_norm, | |||||
power_norm=p_norm, | |||||
random_seed=42, | |||||
).to(device) | |||||
def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None: | |||||
# Parameter initialization | |||||
if margin is None or epsilon is None: | |||||
# If no margin/epsilon are provided, use Xavier uniform initialization | |||||
nn.init.xavier_uniform_(self.ent_embeddings.weight.data) | |||||
nn.init.xavier_uniform_(self.rel_embeddings.weight.data) | |||||
else: | |||||
# Otherwise, initialize uniformly in [ -(margin+epsilon)/dim , +(margin+epsilon)/dim ] | |||||
self.embedding_range = nn.Parameter( | |||||
torch.Tensor([(margin + epsilon) / self.dim]), | |||||
requires_grad=False, | |||||
) | |||||
nn.init.uniform_( | |||||
tensor=self.ent_embeddings.weight.data, | |||||
a=-self.embedding_range.item(), | |||||
b=self.embedding_range.item(), | |||||
) | |||||
nn.init.uniform_( | |||||
tensor=self.rel_embeddings.weight.data, | |||||
a=-self.embedding_range.item(), | |||||
b=self.embedding_range.item(), | |||||
) |
from __future__ import annotations | |||||
from typing import TYPE_CHECKING | |||||
import torch | |||||
from pykeen.losses import MarginRankingLoss | |||||
from pykeen.models import TransH as PyKEENTransH | |||||
from pykeen.regularizers import LpRegularizer | |||||
from torch import nn | |||||
from models.base_model import ERModel | |||||
if TYPE_CHECKING: | |||||
from pykeen.typing import MappedTriples | |||||
class TransH(ERModel): | |||||
""" | |||||
TransH model for knowledge graph embedding. | |||||
""" | |||||
def __init__( | |||||
self, | |||||
num_entities: int, | |||||
num_relations: int, | |||||
triples_factory: MappedTriples, | |||||
dim: int = 200, | |||||
p_norm: bool = True, | |||||
margin: float | None = None, | |||||
epsilon: float | None = None, | |||||
device: torch.device = torch.device("cpu"), | |||||
**kwargs, | |||||
) -> None: | |||||
super().__init__(num_entities, num_relations, dim) | |||||
self.dim = dim | |||||
self.p_norm = p_norm | |||||
self.margin = margin | |||||
self.epsilon = epsilon | |||||
self.inner_model = PyKEENTransH( | |||||
triples_factory=triples_factory, | |||||
embedding_dim=dim, | |||||
scoring_fct_norm=1, | |||||
loss=self.loss, | |||||
power_norm=p_norm, | |||||
entity_initializer="xavier_uniform_", # good default | |||||
relation_initializer="xavier_uniform_", | |||||
random_seed=42, | |||||
).to(device) | |||||
@property | |||||
    def loss(self) -> MarginRankingLoss:
return MarginRankingLoss(margin=self.margin, reduction="mean") | |||||
def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None: | |||||
# Parameter initialization | |||||
if margin is None or epsilon is None: | |||||
# If no margin/epsilon are provided, use Xavier uniform initialization | |||||
nn.init.xavier_uniform_(self.ent_embeddings.weight.data) | |||||
nn.init.xavier_uniform_(self.rel_embeddings.weight.data) | |||||
else: | |||||
# Otherwise, initialize uniformly in [ -(margin+epsilon)/dim , +(margin+epsilon)/dim ] | |||||
self.embedding_range = nn.Parameter( | |||||
torch.Tensor([(margin + epsilon) / self.dim]), | |||||
requires_grad=False, | |||||
) | |||||
nn.init.uniform_( | |||||
tensor=self.ent_embeddings.weight.data, | |||||
a=-self.embedding_range.item(), | |||||
b=self.embedding_range.item(), | |||||
) | |||||
nn.init.uniform_( | |||||
tensor=self.rel_embeddings.weight.data, | |||||
a=-self.embedding_range.item(), | |||||
b=self.embedding_range.item(), | |||||
) |
from __future__ import annotations | |||||
from typing import TYPE_CHECKING | |||||
import torch | |||||
from pykeen.losses import MarginRankingLoss | |||||
from pykeen.models import TransR as PyKEENTransR | |||||
from torch import nn | |||||
from models.base_model import ERModel | |||||
if TYPE_CHECKING: | |||||
from pykeen.typing import MappedTriples | |||||
class TransR(ERModel): | |||||
""" | |||||
TransR model for knowledge graph embedding. | |||||
""" | |||||
def __init__( | |||||
self, | |||||
num_entities: int, | |||||
num_relations: int, | |||||
triples_factory: MappedTriples, | |||||
entity_dim: int = 200, | |||||
relation_dim: int = 30, | |||||
p_norm: bool = True, | |||||
p_norm_value: int = 2, | |||||
margin: float | None = None, | |||||
epsilon: float | None = None, | |||||
device: torch.device = torch.device("cpu"), | |||||
**kwargs, | |||||
) -> None: | |||||
super().__init__(num_entities, num_relations, entity_dim) | |||||
self.entity_dim = entity_dim | |||||
self.relation_dim = relation_dim | |||||
self.p_norm = p_norm | |||||
self.margin = margin | |||||
self.epsilon = epsilon | |||||
self.inner_model = PyKEENTransR( | |||||
triples_factory=triples_factory, | |||||
embedding_dim=entity_dim, | |||||
relation_dim=relation_dim, | |||||
scoring_fct_norm=1, | |||||
loss=self.loss, | |||||
power_norm=p_norm, | |||||
entity_initializer="xavier_uniform_", | |||||
relation_initializer="xavier_uniform_", | |||||
random_seed=42, | |||||
).to(device) | |||||
@property | |||||
    def loss(self) -> MarginRankingLoss:
return MarginRankingLoss(margin=self.margin, reduction="mean") | |||||
def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None: | |||||
# Parameter initialization | |||||
if margin is None or epsilon is None: | |||||
# If no margin/epsilon are provided, use Xavier uniform initialization | |||||
nn.init.xavier_uniform_(self.ent_embeddings.weight.data) | |||||
nn.init.xavier_uniform_(self.rel_embeddings.weight.data) | |||||
else: | |||||
# Otherwise, initialize uniformly in [ -(margin+epsilon)/dim , +(margin+epsilon)/dim ] | |||||
self.embedding_range = nn.Parameter( | |||||
torch.Tensor([(margin + epsilon) / self.dim]), | |||||
requires_grad=False, | |||||
) | |||||
nn.init.uniform_( | |||||
tensor=self.ent_embeddings.weight.data, | |||||
a=-self.embedding_range.item(), | |||||
b=self.embedding_range.item(), | |||||
) | |||||
nn.init.uniform_( | |||||
tensor=self.rel_embeddings.weight.data, | |||||
a=-self.embedding_range.item(), | |||||
b=self.embedding_range.item(), | |||||
) |
[tool.ruff] | |||||
# Exclude a variety of commonly ignored directories. | |||||
exclude = [ | |||||
".bzr", | |||||
".direnv", | |||||
".eggs", | |||||
".git", | |||||
".git-rewrite", | |||||
".hg", | |||||
".ipynb_checkpoints", | |||||
".mypy_cache", | |||||
".nox", | |||||
".pants.d", | |||||
".pyenv", | |||||
".pytest_cache", | |||||
".pytype", | |||||
".ruff_cache", | |||||
".svn", | |||||
".tox", | |||||
".venv", | |||||
".vscode", | |||||
"__pypackages__", | |||||
"_build", | |||||
"buck-out", | |||||
"build", | |||||
"dist", | |||||
"node_modules", | |||||
"site-packages", | |||||
"venv", | |||||
] | |||||
respect-gitignore = true | |||||
# Black's default line length is 88; this project allows 100.
line-length = 100 | |||||
indent-width = 4 | |||||
[tool.ruff.lint]
# Enable all rule families and opt out of the specific codes listed in `ignore`.
select = ["ALL"]
ignore = ["ANN002", "ANN003", "RET504", "ERA001", "E741", "N812", "D", "PLR0913", "T201", "FBT001", "FBT002", "PLW0602", "PLW0603", "PTH", "C901", "PLR2004", "UP006", "UP035", "G004", "N803", "N806", "ARG002"]
# Allow fixes for all enabled rules (when `--fix` is provided).
fixable = ["ALL"]
unfixable = []
# Allow unused variables when underscore-prefixed.
#dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E", "F", "I", "N"]
[tool.ruff.format] | |||||
# Like Black, use double quotes for strings. | |||||
quote-style = "double" | |||||
# Like Black, indent with spaces, rather than tabs. | |||||
indent-style = "space" | |||||
# Like Black, respect magic trailing commas. | |||||
skip-magic-trailing-comma = false | |||||
# Like Black, automatically detect the appropriate line ending. | |||||
line-ending = "auto" | |||||
# Enable auto-formatting of code examples in docstrings. Markdown, | |||||
# reStructuredText code/literal blocks and doctests are all supported. | |||||
# | |||||
# This is currently disabled by default, but it is planned for this | |||||
# to be opt-out in the future. | |||||
docstring-code-format = false | |||||
# Set the line length limit used when formatting code snippets in | |||||
# docstrings. | |||||
# | |||||
# This only has an effect when the `docstring-code-format` setting is | |||||
# enabled. | |||||
docstring-code-line-length = "dynamic" | |||||
[tool.pyright] | |||||
typeCheckingMode = "off" |
#!/usr/bin/env bash | |||||
# install.sh – minimal Conda bootstrap with ~/.bashrc check | |||||
set -e | |||||
ENV=kg-env # env name passed via -n below; it overrides the "name:" field inside the YAML
YAML=kg_env.yaml # assumed to be in the current dir
# ── 1. try to load an existing conda from ~/.bashrc ────────────────────── | |||||
[ -f "$HOME/.bashrc" ] && source "$HOME/.bashrc" || true | |||||
# ── 2. ensure conda exists ─────────────────────────────────────────────── | |||||
if ! command -v conda &>/dev/null; then | |||||
echo "[install.sh] Conda not found → installing Miniconda." | |||||
curl -fsSL https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o miniconda.sh | |||||
bash miniconda.sh -b -p "$HOME/miniconda3" | |||||
rm miniconda.sh | |||||
source "$HOME/miniconda3/etc/profile.d/conda.sh" | |||||
conda init bash >/dev/null # so future shells have it | |||||
else | |||||
# conda exists; load its helper for this shell | |||||
source "$(conda info --base)/etc/profile.d/conda.sh" | |||||
fi | |||||
# ── 3. create or update the project env ────────────────────────────────── | |||||
conda env create -n "$ENV" -f "$YAML" \ | |||||
|| conda env update -n "$ENV" -f "$YAML" --prune | |||||
echo "✔ Done. Activate with: conda activate $ENV" |
# environment.yml ── KG-Embeddings project | |||||
name: kg | |||||
channels: | |||||
- pytorch | |||||
- defaults | |||||
- nvidia | |||||
- conda-forge | |||||
- https://repo.anaconda.com/pkgs/main | |||||
- https://repo.anaconda.com/pkgs/r | |||||
dependencies: | |||||
- _libgcc_mutex=0.1=conda_forge | |||||
- _openmp_mutex=4.5=2_gnu | |||||
- aom=3.6.1=h59595ed_0 | |||||
- blas=1.0=mkl | |||||
- brotli-python=1.1.0=py311hfdbb021_2 | |||||
- bzip2=1.0.8=h4bc722e_7 | |||||
- ca-certificates=2024.12.14=hbcca054_0 | |||||
- certifi=2024.12.14=pyhd8ed1ab_0 | |||||
- cffi=1.17.1=py311hf29c0ef_0 | |||||
- charset-normalizer=3.4.1=pyhd8ed1ab_0 | |||||
- cpython=3.11.11=py311hd8ed1ab_1 | |||||
- cuda-cudart=12.4.127=0 | |||||
- cuda-cupti=12.4.127=0 | |||||
- cuda-libraries=12.4.1=0 | |||||
- cuda-nvrtc=12.4.127=0 | |||||
- cuda-nvtx=12.4.127=0 | |||||
- cuda-opencl=12.6.77=0 | |||||
- cuda-runtime=12.4.1=0 | |||||
- cuda-version=12.6=3 | |||||
- ffmpeg=4.4.2=gpl_hdf48244_113 | |||||
- filelock=3.16.1=pyhd8ed1ab_1 | |||||
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0 | |||||
- font-ttf-inconsolata=3.000=h77eed37_0 | |||||
- font-ttf-source-code-pro=2.038=h77eed37_0 | |||||
- font-ttf-ubuntu=0.83=h77eed37_3 | |||||
- fontconfig=2.15.0=h7e30c49_1 | |||||
- fonts-conda-ecosystem=1=0 | |||||
- fonts-conda-forge=1=0 | |||||
- freetype=2.12.1=h267a509_2 | |||||
- gettext=0.22.5=he02047a_3 | |||||
- gettext-tools=0.22.5=he02047a_3 | |||||
- giflib=5.2.2=hd590300_0 | |||||
- gmp=6.3.0=hac33072_2 | |||||
- gmpy2=2.1.5=py311h0f6cedb_3 | |||||
- gnutls=3.7.9=hb077bed_0 | |||||
- h2=4.1.0=pyhd8ed1ab_1 | |||||
- hpack=4.0.0=pyhd8ed1ab_1 | |||||
- hyperframe=6.0.1=pyhd8ed1ab_1 | |||||
- idna=3.10=pyhd8ed1ab_1 | |||||
- intel-openmp=2022.0.1=h06a4308_3633 | |||||
- jinja2=3.1.5=pyhd8ed1ab_0 | |||||
- lame=3.100=h166bdaf_1003 | |||||
- lcms2=2.16=hb7c19ff_0 | |||||
- ld_impl_linux-64=2.43=h712a8e2_2 | |||||
- lerc=4.0.0=h27087fc_0 | |||||
- libasprintf=0.22.5=he8f35ee_3 | |||||
- libasprintf-devel=0.22.5=he8f35ee_3 | |||||
- libblas=3.9.0=16_linux64_mkl | |||||
- libcblas=3.9.0=16_linux64_mkl | |||||
- libcublas=12.4.5.8=0 | |||||
- libcufft=11.2.1.3=0 | |||||
- libcufile=1.11.1.6=0 | |||||
- libcurand=10.3.7.77=0 | |||||
- libcusolver=11.6.1.9=0 | |||||
- libcusparse=12.3.1.170=0 | |||||
- libdeflate=1.23=h4ddbbb0_0 | |||||
- libdrm=2.4.124=hb9d3cd8_0 | |||||
- libegl=1.7.0=ha4b6fd6_2 | |||||
- libexpat=2.6.4=h5888daf_0 | |||||
- libffi=3.4.2=h7f98852_5 | |||||
- libgcc=14.2.0=h77fa898_1 | |||||
- libgcc-ng=14.2.0=h69a702a_1 | |||||
- libgettextpo=0.22.5=he02047a_3 | |||||
- libgettextpo-devel=0.22.5=he02047a_3 | |||||
- libgl=1.7.0=ha4b6fd6_2 | |||||
- libglvnd=1.7.0=ha4b6fd6_2 | |||||
- libglx=1.7.0=ha4b6fd6_2 | |||||
- libgomp=14.2.0=h77fa898_1 | |||||
- libiconv=1.17=hd590300_2 | |||||
- libidn2=2.3.7=hd590300_0 | |||||
- libjpeg-turbo=3.0.0=hd590300_1 | |||||
- liblapack=3.9.0=16_linux64_mkl | |||||
- liblzma=5.6.3=hb9d3cd8_1 | |||||
- libnpp=12.2.5.30=0 | |||||
- libnsl=2.0.1=hd590300_0 | |||||
- libnvfatbin=12.6.77=0 | |||||
- libnvjitlink=12.4.127=0 | |||||
- libnvjpeg=12.3.1.117=0 | |||||
- libpciaccess=0.18=hd590300_0 | |||||
- libpng=1.6.45=h943b412_0 | |||||
- libsqlite=3.47.2=hee588c1_0 | |||||
- libstdcxx=14.2.0=hc0a3c3a_1 | |||||
- libstdcxx-ng=14.2.0=h4852527_1 | |||||
- libtasn1=4.19.0=h166bdaf_0 | |||||
- libtiff=4.7.0=hd9ff511_3 | |||||
- libunistring=0.9.10=h7f98852_0 | |||||
- libuuid=2.38.1=h0b41bf4_0 | |||||
- libva=2.22.0=h8a09558_1 | |||||
- libvpx=1.13.1=h59595ed_0 | |||||
- libwebp=1.5.0=hae8dbeb_0 | |||||
- libwebp-base=1.5.0=h851e524_0 | |||||
- libxcb=1.17.0=h8a09558_0 | |||||
- libxcrypt=4.4.36=hd590300_1 | |||||
- libxml2=2.13.5=h0d44e9d_1 | |||||
- libzlib=1.3.1=hb9d3cd8_2 | |||||
- llvm-openmp=15.0.7=h0cdce71_0 | |||||
- markupsafe=3.0.2=py311h2dc5d0c_1 | |||||
- mkl=2022.1.0=hc2b9512_224 | |||||
- mpc=1.3.1=h24ddda3_1 | |||||
- mpfr=4.2.1=h90cbb55_3 | |||||
- mpmath=1.3.0=pyhd8ed1ab_1 | |||||
- ncurses=6.5=h2d0b736_2 | |||||
- nettle=3.9.1=h7ab15ed_0 | |||||
- networkx=3.4.2=pyh267e887_2 | |||||
- numpy=2.2.1=py311hf916aec_0 | |||||
- openh264=2.3.1=hcb278e6_2 | |||||
- openjpeg=2.5.3=h5fbd93e_0 | |||||
- openssl=3.4.0=h7b32b05_1 | |||||
- p11-kit=0.24.1=hc5aa10d_0 | |||||
- pillow=11.1.0=py311h1322bbf_0 | |||||
- pip=24.3.1=pyh8b19718_2 | |||||
- pthread-stubs=0.4=hb9d3cd8_1002 | |||||
- pycparser=2.22=pyh29332c3_1 | |||||
- pysocks=1.7.1=pyha55dd90_7 | |||||
- python=3.11.11=h9e4cc4f_1_cpython | |||||
- python_abi=3.11=5_cp311 | |||||
- pytorch=2.5.1=py3.11_cuda12.4_cudnn9.1.0_0 | |||||
- pytorch-cuda=12.4=hc786d27_7 | |||||
- pytorch-mutex=1.0=cuda | |||||
- pyyaml=6.0.2=py311h9ecbd09_1 | |||||
- readline=8.2=h8228510_1 | |||||
- requests=2.32.3=pyhd8ed1ab_1 | |||||
- setuptools=75.8.0=pyhff2d567_0 | |||||
- svt-av1=1.4.1=hcb278e6_0 | |||||
- sympy=1.13.3=pyh2585a3b_105 | |||||
- tk=8.6.13=noxft_h4845f30_101 | |||||
- torchaudio=2.5.1=py311_cu124 | |||||
- torchtriton=3.1.0=py311 | |||||
- torchvision=0.20.1=py311_cu124 | |||||
- typing_extensions=4.12.2=pyha770c72_1 | |||||
- tzdata=2024b=hc8b5060_0 | |||||
- urllib3=2.3.0=pyhd8ed1ab_0 | |||||
- wayland=1.23.1=h3e06ad9_0 | |||||
- wayland-protocols=1.37=hd8ed1ab_0 | |||||
- wheel=0.45.1=pyhd8ed1ab_1 | |||||
- x264=1!164.3095=h166bdaf_2 | |||||
- x265=3.5=h924138e_3 | |||||
- xorg-libx11=1.8.10=h4f16b4b_1 | |||||
- xorg-libxau=1.0.12=hb9d3cd8_0 | |||||
- xorg-libxdmcp=1.1.5=hb9d3cd8_0 | |||||
- xorg-libxext=1.3.6=hb9d3cd8_0 | |||||
- xorg-libxfixes=6.0.1=hb9d3cd8_0 | |||||
- yaml=0.2.5=h7f98852_2 | |||||
- zstandard=0.23.0=py311hbc35293_1 | |||||
- zstd=1.5.6=ha6fb4c9_0 | |||||
- pip: | |||||
- torchmetrics>=1.4 # nicer MRR/Hits@K utilities | |||||
- lovely-tensors>=0.1.0 # tensor utilities | |||||
- lovely-numpy>=0.1.0 # numpy utilities | |||||
- einops==0.8.1 | |||||
- pykeen==1.11.1 | |||||
- hydra-core==1.3.2 | |||||
- tensorboard==2.19.0 |
from .pretty_logger import get_pretty_logger | |||||
from .sampling import DeterministicRandomSampler | |||||
from .utils import set_seed | |||||
from .tb_handler import TensorBoardHandler | |||||
from .params import CommonParams, TrainingParams | |||||
from .checkpoint_manager import CheckpointManager |
from __future__ import annotations | |||||
import ast
import re
from pathlib import Path | |||||
class CheckpointManager: | |||||
""" | |||||
A checkpoint manager that handles checkpoint directory creation and path resolution. | |||||
Features: | |||||
- Creates sequentially numbered checkpoint directories | |||||
- Supports two loading modes: by components or by full path | |||||
""" | |||||
def __init__( | |||||
self, | |||||
root_directory: str | Path, | |||||
run_name: str, | |||||
load_only: bool = False, | |||||
) -> None: | |||||
""" | |||||
Initialize the checkpoint manager. | |||||
Args: | |||||
root_directory: Root directory for checkpoints | |||||
run_name: Name of the run (used as suffix for checkpoint directories) | |||||
""" | |||||
self.root_directory = Path(root_directory) | |||||
self.run_name = run_name | |||||
self.checkpoint_directory = None | |||||
if not load_only: | |||||
self.root_directory.mkdir(parents=True, exist_ok=True) | |||||
self.checkpoint_directory = self._create_checkpoint_directory() | |||||
else: | |||||
self.checkpoint_directory = "" | |||||
def _find_existing_directories(self) -> list[int]: | |||||
""" | |||||
Find all existing directories with pattern xxx_<run_name> where xxx is a 3-digit number. | |||||
Returns: | |||||
List of existing sequence numbers | |||||
""" | |||||
pattern = re.compile(rf"^(\d{{3}})_{re.escape(self.run_name)}$") | |||||
existing_numbers = [] | |||||
if self.root_directory.exists(): | |||||
for item in self.root_directory.iterdir(): | |||||
if item.is_dir(): | |||||
match = pattern.match(item.name) | |||||
if match: | |||||
existing_numbers.append(int(match.group(1))) | |||||
return sorted(existing_numbers) | |||||
def _create_checkpoint_directory(self) -> Path: | |||||
""" | |||||
Create a new checkpoint directory with the next sequential number. | |||||
Returns: | |||||
Path to the created checkpoint directory | |||||
""" | |||||
existing_numbers = self._find_existing_directories() | |||||
# Determine the next number | |||||
next_number = max(existing_numbers) + 1 if existing_numbers else 1 | |||||
# Create directory name with 3-digit zero-padded number | |||||
dir_name = f"{next_number:03d}_{self.run_name}" | |||||
checkpoint_dir = self.root_directory / dir_name | |||||
self.run_name = dir_name | |||||
# Create the directory | |||||
checkpoint_dir.mkdir(parents=True, exist_ok=True) | |||||
return checkpoint_dir | |||||
def get_checkpoint_directory(self) -> Path: | |||||
""" | |||||
Get the current checkpoint directory. | |||||
Returns: | |||||
Path to the checkpoint directory | |||||
""" | |||||
return self.checkpoint_directory | |||||
    def get_model_fpath(self, model_path: str) -> Path:
        """
        Get the full path to a model checkpoint file.
        Args:
            model_path: Either a string encoding a tuple "(model_id, iteration)"
                or a plain string path to a checkpoint file
        Returns:
            Full path to the model checkpoint file
        """
        try:
            # Safely parse "(model_id, iteration)" strings; plain paths fail this parse.
            parsed = ast.literal_eval(model_path)
            if isinstance(parsed, tuple):
                model_id, iteration = parsed
                checkpoint_directory = f"{self.root_directory}/{model_id:03d}_{self.run_name}"
                filename = f"model-{iteration}.pt"
                return Path(f"{checkpoint_directory}/{filename}")
        except (ValueError, SyntaxError):
            pass
        if isinstance(model_path, str):
            return Path(model_path)
        msg = "model_path must encode a tuple (model_id, iteration) or be a string path"
        raise ValueError(msg)
def get_model_path_by_args( | |||||
self, | |||||
model_id: str | None = None, | |||||
iteration: int | str | None = None, | |||||
full_path: str | Path | None = None, | |||||
) -> Path: | |||||
""" | |||||
Get the path for loading a model checkpoint. | |||||
Two modes of operation: | |||||
1. Component mode: Provide model_id and iteration to construct path | |||||
2. Full path mode: Provide full_path directly | |||||
Args: | |||||
model_id: Model identifier (used in component mode) | |||||
iteration: Training iteration (used in component mode) | |||||
full_path: Full path to checkpoint (used in full path mode) | |||||
Returns: | |||||
Path to the checkpoint file | |||||
Raises: | |||||
ValueError: If neither component parameters nor full_path are provided, | |||||
or if both modes are attempted simultaneously | |||||
""" | |||||
# Check which mode we're operating in | |||||
component_mode = model_id is not None or iteration is not None | |||||
full_path_mode = full_path is not None | |||||
if component_mode and full_path_mode: | |||||
msg = ( | |||||
"Cannot use both component mode (model_id/iteration) " | |||||
"and full path mode simultaneously" | |||||
) | |||||
raise ValueError(msg) | |||||
if not component_mode and not full_path_mode: | |||||
msg = "Must provide either (model_id and iteration) or full_path" | |||||
raise ValueError(msg) | |||||
if full_path_mode: | |||||
# Full path mode: return the path as-is | |||||
return Path(full_path) | |||||
# Component mode: construct path from checkpoint directory, model_id, and iteration | |||||
if model_id is None or iteration is None: | |||||
msg = "Both model_id and iteration must be provided in component mode" | |||||
raise ValueError(msg) | |||||
filename = f"{model_id}_iter_{iteration}.pt" | |||||
return self.checkpoint_directory / filename | |||||
def save_checkpoint_info(self, info: dict) -> None: | |||||
""" | |||||
Save checkpoint information to a JSON file in the checkpoint directory. | |||||
Args: | |||||
info: Dictionary containing checkpoint metadata | |||||
""" | |||||
import json | |||||
info_file = self.checkpoint_directory / "checkpoint_info.json" | |||||
with open(info_file, "w") as f: | |||||
json.dump(info, f, indent=2) | |||||
def load_checkpoint_info(self) -> dict: | |||||
""" | |||||
Load checkpoint information from the checkpoint directory. | |||||
Returns: | |||||
Dictionary containing checkpoint metadata | |||||
""" | |||||
import json | |||||
info_file = self.checkpoint_directory / "checkpoint_info.json" | |||||
if info_file.exists(): | |||||
with open(info_file) as f: | |||||
return json.load(f) | |||||
return {} | |||||
def __str__(self) -> str: | |||||
return ( | |||||
f"CheckpointManager(root='{self.root_directory}', " | |||||
f"run='{self.run_name}', ckpt_dir='{self.checkpoint_directory}')" | |||||
) | |||||
def __repr__(self) -> str: | |||||
return self.__str__() | |||||
# Example usage: | |||||
if __name__ == "__main__": | |||||
import tempfile | |||||
# Create checkpoint manager with proper temp directory | |||||
with tempfile.TemporaryDirectory() as temp_dir: | |||||
ckpt_manager = CheckpointManager(temp_dir, "my_experiment") | |||||
print(f"Checkpoint directory: {ckpt_manager.get_checkpoint_directory()}") | |||||
# Component mode - construct path from components | |||||
model_path = ckpt_manager.get_model_path_by_args(model_id="transe", iteration=1000) | |||||
print(f"Model path (component mode): {model_path}") | |||||
# Full path mode - use existing full path | |||||
full_path = "/some/other/path/model.pt" | |||||
model_path = ckpt_manager.get_model_path_by_args(full_path=full_path) | |||||
print(f"Model path (full path mode): {model_path}") |
from __future__ import annotations | |||||
from dataclasses import dataclass | |||||
@dataclass | |||||
class CommonParams: | |||||
""" | |||||
Common parameters for the application. | |||||
""" | |||||
model_name: str | |||||
project_name: str | |||||
run_name: str | |||||
save_dpath: str | |||||
save_every: int = 1000 | |||||
load_path: str | None = "" | |||||
log_dir: str = "./runs" | |||||
log_every: int = 10 | |||||
log_console_every: int = 100 | |||||
evaluate_only: bool = False | |||||
@dataclass | |||||
class TrainingParams: | |||||
""" | |||||
Parameters for training. | |||||
""" | |||||
lr: float = 0.001 | |||||
min_lr: float = 0.0001 | |||||
weight_decay: float = 0.0 | |||||
t0: int = 50 | |||||
lr_step: int = 10 | |||||
gamma: float = 1.0 | |||||
batch_size: int = 128 | |||||
num_workers: int = 0 | |||||
seed: int = 42 | |||||
num_train_steps: int = 100 | |||||
num_epochs: int = 100 | |||||
eval_every: int = 5 | |||||
validation_sample_size: int = 1000 | |||||
validation_batch_size: int = 128 |
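# Hedged usage sketch (field values are illustrative): both dataclasses are plain
# containers, so they can be filled directly in Python as well as from a config file.
if __name__ == "__main__":
    common = CommonParams(
        model_name="transe",
        project_name="kg_eval",
        run_name="debug_run",
        save_dpath="./checkpoints",
        save_every=500,
    )
    training = TrainingParams(lr=1e-3, batch_size=256, num_train_steps=1000)
    print(common)
    print(training)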
# pretty_logger.py | |||||
import json | |||||
import logging | |||||
import logging.handlers | |||||
import sys | |||||
import threading | |||||
import time | |||||
from contextlib import ContextDecorator | |||||
from functools import wraps | |||||
# ANSI escape codes for colors | |||||
RESET = "\x1b[0m" | |||||
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = ( | |||||
"\x1b[30m", | |||||
"\x1b[31m", | |||||
"\x1b[32m", | |||||
"\x1b[33m", | |||||
"\x1b[34m", | |||||
"\x1b[35m", | |||||
"\x1b[36m", | |||||
"\x1b[37m", | |||||
) | |||||
_LEVEL_COLORS = { | |||||
logging.DEBUG: CYAN, | |||||
logging.INFO: GREEN, | |||||
logging.WARNING: YELLOW, | |||||
logging.ERROR: RED, | |||||
logging.CRITICAL: MAGENTA, | |||||
} | |||||
def _supports_color() -> bool:
    """
    Detect whether the running terminal supports ANSI colors.
    """
    # Assume ANSI support whenever stdout is a TTY; modern Windows terminals
    # (Windows 10+) handle ANSI escapes as well.
    return bool(hasattr(sys.stdout, "isatty") and sys.stdout.isatty())
class PrettyFormatter(logging.Formatter): | |||||
""" | |||||
A logging.Formatter that outputs colorized logs to the console | |||||
(if enabled) and supports JSON formatting (if requested). | |||||
""" | |||||
def __init__( | |||||
self, | |||||
fmt: str = None, | |||||
datefmt: str = "%Y-%m-%d %H:%M:%S", | |||||
use_colors: bool = True, | |||||
json_output: bool = False, | |||||
): | |||||
super().__init__(fmt=fmt, datefmt=datefmt) | |||||
self.use_colors = use_colors and _supports_color() | |||||
self.json_output = json_output | |||||
def format(self, record: logging.LogRecord) -> str: | |||||
# If JSON output is requested, dump a dict of relevant fields | |||||
if self.json_output: | |||||
log_record = { | |||||
"timestamp": self.formatTime(record, self.datefmt), | |||||
"name": record.name, | |||||
"level": record.levelname, | |||||
"message": record.getMessage(), | |||||
} | |||||
# Include extra/contextual fields | |||||
for k, v in record.__dict__.get("extra_fields", {}).items(): | |||||
log_record[k] = v | |||||
if record.exc_info: | |||||
log_record["exception"] = self.formatException(record.exc_info) | |||||
return json.dumps(log_record, ensure_ascii=False) | |||||
# Otherwise, build a colorized/“pretty” line | |||||
level_color = _LEVEL_COLORS.get(record.levelno, WHITE) if self.use_colors else "" | |||||
reset = RESET if self.use_colors else "" | |||||
timestamp = self.formatTime(record, self.datefmt) | |||||
name = record.name | |||||
level = record.levelname | |||||
# Basic format: [timestamp] [LEVEL] [name]: message | |||||
base = f"[{timestamp}] [{level}] [{name}]: {record.getMessage()}" | |||||
# If there are any extra fields attached via LoggerAdapter, append them | |||||
extra_fields = record.__dict__.get("extra_fields", {}) | |||||
if extra_fields: | |||||
# key1=val1 key2=val2 … | |||||
extras_str = " ".join(f"{k}={v}" for k, v in extra_fields.items()) | |||||
base = f"{base} :: {extras_str}" | |||||
# If exception info is present, include traceback | |||||
if record.exc_info: | |||||
exc_text = self.formatException(record.exc_info) | |||||
base = f"{base}\n{exc_text}" | |||||
if self.use_colors: | |||||
return f"{level_color}{base}{reset}" | |||||
else: | |||||
return base | |||||
class JsonFormatter(PrettyFormatter): | |||||
""" | |||||
Shortcut to force JSON formatting (ignores color flags). | |||||
""" | |||||
def __init__(self, datefmt: str = "%Y-%m-%d %H:%M:%S"): | |||||
super().__init__(fmt=None, datefmt=datefmt, use_colors=False, json_output=True) | |||||
class ContextFilter(logging.Filter): | |||||
""" | |||||
Inject extra_fields from a LoggerAdapter into the LogRecord, so the Formatter can find them. | |||||
""" | |||||
def filter(self, record: logging.LogRecord) -> bool: | |||||
extra = getattr(record, "extra", None) | |||||
if isinstance(extra, dict): | |||||
record.extra_fields = extra | |||||
else: | |||||
record.extra_fields = {} | |||||
return True | |||||
class Timer(ContextDecorator): | |||||
""" | |||||
Context manager & decorator to measure execution time of a code block or function. | |||||
Usage as context manager: | |||||
with Timer(logger, "block name"): | |||||
# code … | |||||
Usage as decorator: | |||||
@Timer(logger, "function foo") | |||||
def foo(...): | |||||
... | |||||
""" | |||||
def __init__(self, logger: logging.Logger, name: str = None, level: int = logging.INFO): | |||||
self.logger = logger | |||||
self.name = name or "" | |||||
self.level = level | |||||
def __enter__(self): | |||||
self.start_time = time.time() | |||||
return self | |||||
def __exit__(self, exc_type, exc, exc_tb): | |||||
elapsed = time.time() - self.start_time | |||||
label = f"[Timer:{self.name}]" if self.name else "[Timer]" | |||||
self.logger.log(self.level, f"{label} elapsed {elapsed:.4f} sec") | |||||
return False # do not suppress exceptions | |||||
def log_function(logger: logging.Logger, level: int = logging.DEBUG): | |||||
""" | |||||
Decorator factory to log function entry/exit and execution time. | |||||
Usage: | |||||
@log_function(logger, level=logging.INFO) | |||||
def my_func(...): | |||||
... | |||||
""" | |||||
def decorator(func): | |||||
@wraps(func) | |||||
def wrapper(*args, **kwargs): | |||||
logger.log(level, f"➤ Entering {func.__name__}()") | |||||
start = time.time() | |||||
try: | |||||
result = func(*args, **kwargs) | |||||
except Exception: | |||||
logger.exception(f"✖ Exception in {func.__name__}()") | |||||
raise | |||||
elapsed = time.time() - start | |||||
logger.log(level, f"✔ Exiting {func.__name__}(), elapsed {elapsed:.4f} sec") | |||||
return result | |||||
return wrapper | |||||
return decorator | |||||
class PrettyLogger: | |||||
""" | |||||
Encapsulates a logging.Logger with “pretty” formatting, colorized console output, | |||||
rotating file handler, JSON option, contextual fields, etc. | |||||
""" | |||||
_instances = {} | |||||
_lock = threading.Lock() | |||||
def __new__(cls, name: str = "root", **kwargs): | |||||
""" | |||||
Implements a simple singleton: multiple calls with the same name return the same instance. | |||||
""" | |||||
with cls._lock: | |||||
if name not in cls._instances: | |||||
instance = super().__new__(cls) | |||||
cls._instances[name] = instance | |||||
return cls._instances[name] | |||||
def __init__( | |||||
self, | |||||
name: str = "root", | |||||
level: int = logging.DEBUG, | |||||
log_to_console: bool = True, | |||||
console_level: int = logging.DEBUG, | |||||
log_to_file: bool = False, | |||||
log_file: str = "app.log", | |||||
file_level: int = logging.INFO, | |||||
max_bytes: int = 10 * 1024 * 1024, # 10 MB | |||||
backup_count: int = 5, | |||||
formatter: PrettyFormatter = None, | |||||
json_output: bool = False, | |||||
): | |||||
""" | |||||
Initialize (or re‐configure) a PrettyLogger. | |||||
Args: | |||||
name: logger name. | |||||
level: root level for the logger (lowest level that will be processed). | |||||
log_to_console: whether to attach a console (StreamHandler). | |||||
console_level: level for console output. | |||||
log_to_file: whether to attach a rotating file handler. | |||||
log_file: path to the log file. | |||||
file_level: level for file output. | |||||
max_bytes: rotation threshold in bytes. | |||||
backup_count: number of backup files to keep. | |||||
formatter: an instance of PrettyFormatter. If None, a default is created. | |||||
json_output: if True, forces JSON formatting on all handlers. | |||||
""" | |||||
# Avoid re‐initializing if already done | |||||
logger = logging.getLogger(name) | |||||
if getattr(logger, "_pretty_logger_inited", False): | |||||
return | |||||
self.logger = logger | |||||
self.logger.setLevel(level) | |||||
self.logger.propagate = False # avoid double‐logging if root logger is configured elsewhere | |||||
        # Default line format (also reused by the rotating file handler below, so define
        # it even when a custom formatter is supplied)
        fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s"
        # Create a default formatter if not supplied
        if formatter is None:
            formatter = PrettyFormatter(
                fmt=fmt, use_colors=not json_output, json_output=json_output
            )
# Add ContextFilter so extra_fields is always present | |||||
context_filter = ContextFilter() | |||||
self.logger.addFilter(context_filter) | |||||
# Console handler | |||||
if log_to_console: | |||||
ch = logging.StreamHandler(sys.stdout) | |||||
ch.setLevel(console_level) | |||||
ch.setFormatter(formatter) | |||||
self.logger.addHandler(ch) | |||||
# Rotating file handler | |||||
if log_to_file: | |||||
fh = logging.handlers.RotatingFileHandler( | |||||
filename=log_file, | |||||
maxBytes=max_bytes, | |||||
backupCount=backup_count, | |||||
encoding="utf-8", | |||||
) | |||||
fh.setLevel(file_level) | |||||
# For file, typically we don’t colorize | |||||
file_formatter = ( | |||||
JsonFormatter(datefmt="%Y-%m-%d %H:%M:%S") | |||||
if json_output | |||||
else PrettyFormatter(fmt=fmt, use_colors=False, json_output=False) | |||||
) | |||||
fh.setFormatter(file_formatter) | |||||
self.logger.addHandler(fh) | |||||
# Mark as initialized | |||||
setattr(self.logger, "_pretty_logger_inited", True) | |||||
def addHandler(self, handler: logging.Handler): | |||||
""" | |||||
Add a custom handler to the logger. | |||||
This can be used to add any logging.Handler instance. | |||||
""" | |||||
self.logger.addHandler(handler) | |||||
def bind(self, **kwargs): | |||||
""" | |||||
Return a LoggerAdapter that automatically injects extra fields into all log records. | |||||
Example: | |||||
ctx_logger = pretty_logger.bind(request_id=123, user="alice") | |||||
ctx_logger.info("Something happened") # will include request_id and user in the output | |||||
""" | |||||
return logging.LoggerAdapter(self.logger, {"extra": kwargs}) | |||||
def add_stream_handler( | |||||
self, stream=None, level: int = logging.DEBUG, formatter: PrettyFormatter = None | |||||
): | |||||
""" | |||||
Add an additional stream handler (e.g. to stderr). | |||||
""" | |||||
handler = logging.StreamHandler(stream or sys.stdout) | |||||
handler.setLevel(level) | |||||
if formatter is None: | |||||
fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s" | |||||
formatter = PrettyFormatter(fmt=fmt) | |||||
handler.setFormatter(formatter) | |||||
self.logger.addHandler(handler) | |||||
return handler | |||||
def add_rotating_file_handler( | |||||
self, | |||||
log_file: str, | |||||
level: int = logging.INFO, | |||||
max_bytes: int = 10 * 1024 * 1024, | |||||
backup_count: int = 5, | |||||
json_output: bool = False, | |||||
): | |||||
""" | |||||
Add another rotating file handler at runtime. | |||||
""" | |||||
fh = logging.handlers.RotatingFileHandler( | |||||
filename=log_file, | |||||
maxBytes=max_bytes, | |||||
backupCount=backup_count, | |||||
encoding="utf-8", | |||||
) | |||||
fh.setLevel(level) | |||||
if json_output: | |||||
formatter = JsonFormatter() | |||||
else: | |||||
fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s" | |||||
formatter = PrettyFormatter(fmt=fmt, use_colors=False, json_output=False) | |||||
fh.setFormatter(formatter) | |||||
self.logger.addHandler(fh) | |||||
return fh | |||||
def remove_handler(self, handler: logging.Handler): | |||||
""" | |||||
Remove a handler from the logger. | |||||
""" | |||||
self.logger.removeHandler(handler) | |||||
def set_level(self, level: int): | |||||
""" | |||||
Change the logger’s root level at runtime. | |||||
""" | |||||
self.logger.setLevel(level) | |||||
def get_logger(self) -> logging.Logger: | |||||
return self.logger | |||||
# Convenience methods to expose common logging calls directly from this object: | |||||
def debug(self, msg, *args, **kwargs): | |||||
self.logger.debug(msg, *args, **kwargs) | |||||
def info(self, msg, *args, **kwargs): | |||||
self.logger.info(msg, *args, **kwargs) | |||||
def warning(self, msg, *args, **kwargs): | |||||
self.logger.warning(msg, *args, **kwargs) | |||||
def error(self, msg, *args, **kwargs): | |||||
self.logger.error(msg, *args, **kwargs) | |||||
def critical(self, msg, *args, **kwargs): | |||||
self.logger.critical(msg, *args, **kwargs) | |||||
def exception(self, msg, *args, exc_info=True, **kwargs): | |||||
""" | |||||
Log an error message with exception traceback. By default exc_info=True. | |||||
""" | |||||
self.logger.error(msg, *args, exc_info=exc_info, **kwargs) | |||||
# Convenience function for a module‐level “global” pretty logger: | |||||
_global_loggers = {} | |||||
_global_lock = threading.Lock() | |||||
def get_pretty_logger( | |||||
name: str = "root", | |||||
level: int = logging.DEBUG, | |||||
log_to_console: bool = True, | |||||
console_level: int = logging.DEBUG, | |||||
log_to_file: bool = False, | |||||
log_file: str = "app.log", | |||||
file_level: int = logging.INFO, | |||||
max_bytes: int = 10 * 1024 * 1024, | |||||
backup_count: int = 5, | |||||
json_output: bool = False, | |||||
) -> PrettyLogger: | |||||
""" | |||||
Return a PrettyLogger instance for `name`, creating/configuring it on first use. | |||||
Subsequent calls with the same `name` will return the same logger (singleton‐style). | |||||
""" | |||||
with _global_lock: | |||||
if name not in _global_loggers: | |||||
pl = PrettyLogger( | |||||
name=name, | |||||
level=level, | |||||
log_to_console=log_to_console, | |||||
console_level=console_level, | |||||
log_to_file=log_to_file, | |||||
log_file=log_file, | |||||
file_level=file_level, | |||||
max_bytes=max_bytes, | |||||
backup_count=backup_count, | |||||
json_output=json_output, | |||||
) | |||||
_global_loggers[name] = pl | |||||
return _global_loggers[name] |
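# Hedged usage sketch (names and values are illustrative): module-level logger,
# contextual fields via bind(), and the Timer / log_function helpers defined above.
if __name__ == "__main__":
    demo_log = get_pretty_logger("pretty_logger_demo")
    demo_log.info("plain message")
    ctx_log = demo_log.bind(request_id=123, user="alice")
    ctx_log.info("message with extra fields appended as key=value pairs")
    with Timer(demo_log.get_logger(), "sleepy block"):
        time.sleep(0.05)
    @log_function(demo_log.get_logger(), level=logging.INFO)
    def add(a: int, b: int) -> int:
        return a + b
    add(2, 3)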
from collections.abc import Iterator
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
def negative_sampling(
    pos: torch.Tensor,
    num_entities: int,
    num_negatives: int = 1,
    mode: tuple = ("head", "tail"),
) -> torch.Tensor:
    """
    Corrupt a batch of (head, relation, tail) triples by replacing heads and/or tails
    with uniformly sampled entities. Each positive is repeated `num_negatives` times.
    """
    # repeat_interleave returns a new tensor, so the caller's `pos` is not modified
    pos = pos.repeat_interleave(num_negatives, dim=0)
    if "head" in mode:
        neg = torch.randint(0, num_entities, (pos.shape[0],), device=pos.device)
        pos[:, 0] = neg
    if "tail" in mode:
        neg = torch.randint(0, num_entities, (pos.shape[0],), device=pos.device)
        pos[:, 2] = neg
    return pos
class DeterministicRandomSampler(Sampler): | |||||
""" | |||||
A deterministic random sampler that selects a fixed number of samples | |||||
in a reproducible manner using a given seed. | |||||
""" | |||||
    def __init__(self, dataset_size: int, sample_size: int = 0, seed: int = 42) -> None:
        """
        Args:
            dataset_size (int): Total number of items in the dataset.
            sample_size (int, optional): Number of samples to draw. If 0, use the full dataset.
            seed (int): Seed for reproducibility.
        """
self.dataset_size = dataset_size | |||||
self.seed = seed | |||||
self.sample_size = sample_size if sample_size != 0 else self.dataset_size | |||||
# Ensure sample size is within dataset size | |||||
if self.sample_size > self.dataset_size: | |||||
msg = f"Sample size {self.sample_size} exceeds dataset size {self.dataset_size}." | |||||
raise ValueError( | |||||
msg, | |||||
) | |||||
self.indices = self._generate_deterministic_indices() | |||||
def _generate_deterministic_indices(self) -> list[int]: | |||||
""" | |||||
Generates a fixed random subset of indices using the given seed. | |||||
""" | |||||
rng = np.random.default_rng(self.seed) # NumPy's Generator for better reproducibility | |||||
all_indices = rng.permutation(self.dataset_size) # Shuffle full dataset indices | |||||
return all_indices[: self.sample_size].tolist() # Select only the desired number of samples | |||||
    def __iter__(self) -> Iterator[int]:
""" | |||||
Yields the shuffled dataset indices. | |||||
""" | |||||
return iter(self.indices) | |||||
def __len__(self) -> int: | |||||
""" | |||||
Returns the total number of samples drawn. | |||||
""" | |||||
return self.sample_size |
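# Hedged sketch (toy triples, seeds and sizes are illustrative): corrupt a small batch
# of (head, relation, tail) triples and draw a reproducible validation subset.
if __name__ == "__main__":
    pos = torch.tensor([[0, 0, 1], [2, 1, 3]])
    neg = negative_sampling(pos, num_entities=5, num_negatives=2, mode=("tail",))
    print(neg)  # 4 rows; only the tail column has been replaced
    sampler = DeterministicRandomSampler(dataset_size=100, sample_size=10, seed=42)
    print(list(sampler))  # the same 10 indices on every run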
from __future__ import annotations | |||||
import logging | |||||
from typing import TYPE_CHECKING | |||||
if TYPE_CHECKING: | |||||
from torch.utils.tensorboard import SummaryWriter | |||||
class TensorBoardHandler(logging.Handler): | |||||
"""Convert each log record into a TensorBoard text summary.""" | |||||
def __init__(self, writer: SummaryWriter, tag: str = "logs"): | |||||
super().__init__(level=logging.INFO) | |||||
self.writer = writer | |||||
self.tag = tag | |||||
self._step = 0 # will be incremented every emit | |||||
def emit(self, record: logging.LogRecord) -> None: | |||||
self.writer.add_text(self.tag, self.format(record), global_step=self._step) | |||||
self._step += 1 |
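# Hedged sketch (log directory is illustrative): route ordinary log records into
# TensorBoard as text summaries via the handler defined above.
if __name__ == "__main__":
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(log_dir="./runs/tb_handler_demo")
    tb_handler = TensorBoardHandler(writer, tag="logs")
    tb_handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
    demo_logger = logging.getLogger("tb_demo")
    demo_logger.setLevel(logging.INFO)
    demo_logger.addHandler(tb_handler)
    demo_logger.info("this record shows up under the 'logs' text tab in TensorBoard")
    writer.close()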
from torch.optim.lr_scheduler import StepLR | |||||
class StepMinLR(StepLR): | |||||
def __init__(self, optimizer, step_size, gamma=0.1, min_lr=0.0, last_epoch=-1): | |||||
self.min_lr = min_lr | |||||
super().__init__(optimizer, step_size, gamma, last_epoch) | |||||
def get_lr(self): | |||||
# use StepLR's formula, then clamp | |||||
raw_lrs = super().get_lr() | |||||
return [max(lr, self.min_lr) for lr in raw_lrs] |
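# Hedged sketch (optimiser and values are illustrative): the schedule decays like
# StepLR but the learning rate never drops below min_lr.
if __name__ == "__main__":
    import torch
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=0.1)
    sched = StepMinLR(opt, step_size=1, gamma=0.1, min_lr=0.01)
    for epoch in range(4):
        print(epoch, opt.param_groups[0]["lr"])  # 0.1, then clamped at 0.01
        opt.step()
        sched.step()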
import random | |||||
from enum import Enum | |||||
import numpy as np | |||||
import torch | |||||
def set_seed(seed: int = 42) -> None: | |||||
""" | |||||
Set the random seed for reproducibility. | |||||
Args: | |||||
seed (int): The seed value to set for random number generation. | |||||
""" | |||||
torch.manual_seed(seed) | |||||
random.seed(seed) | |||||
    # Seed NumPy's global RNG. `default_rng(seed)` alone would only create (and discard)
    # a local Generator without affecting np.random.* calls elsewhere.
    np.random.seed(seed)
if torch.cuda.is_available(): | |||||
torch.cuda.manual_seed_all(seed) | |||||
class Target(str, Enum): # ↔ PyKEEN's LABEL_* constants | |||||
HEAD = "head" | |||||
TAIL = "tail" | |||||
RELATION = "relation" |
from __future__ import annotations | |||||
from abc import ABC, abstractmethod | |||||
from typing import TYPE_CHECKING | |||||
import torch | |||||
from torch import nn | |||||
if TYPE_CHECKING: | |||||
from torch.optim.lr_scheduler import LRScheduler | |||||
class ModelTrainerBase(ABC): | |||||
def __init__( | |||||
self, | |||||
model: nn.Module, | |||||
optimizer: torch.optim.Optimizer, | |||||
scheduler: LRScheduler | None = None, | |||||
device: torch.device = "cpu", | |||||
) -> None: | |||||
self.model = model.to(device) | |||||
self.optimizer = optimizer | |||||
self.scheduler = scheduler | |||||
self.device = device | |||||
def save(self, save_dpath: str, step: int) -> None: | |||||
""" | |||||
Save the model and optimizer state to the specified directory. | |||||
        Args:
            save_dpath (str): The directory path where the model and optimizer state will be saved.
            step (int): Current training step; embedded in the checkpoint file name.
        """
data = { | |||||
"model": self.model.state_dict(), | |||||
"optimizer": self.optimizer.state_dict(), | |||||
"step": step, | |||||
} | |||||
if self.scheduler is not None: | |||||
data["scheduler"] = self.scheduler.state_dict() | |||||
torch.save( | |||||
data, | |||||
f"{save_dpath}/model-{step}.pt", | |||||
) | |||||
    def load(self, load_fpath: str) -> int:
        """
        Load the model and optimizer state from the specified checkpoint file.
        Args:
            load_fpath (str): Path to a checkpoint file written by `save`.
        Returns:
            int: The training step stored in the checkpoint.
        """
        data = torch.load(load_fpath, map_location="cpu")
        model_state_dict = data["model"]
        self.model.load_state_dict(model_state_dict)
        self.model.to(self.device)  # move everything to the correct device
optimizer_state_dict = data["optimizer"] | |||||
self.optimizer.load_state_dict(optimizer_state_dict) | |||||
if self.scheduler is not None: | |||||
scheduler_data = data.get("scheduler", {}) | |||||
self.scheduler.load_state_dict(scheduler_data) | |||||
return data["step"] | |||||
@abstractmethod | |||||
def _loss(self, batch: torch.Tensor) -> torch.Tensor: | |||||
""" | |||||
Compute the loss for a batch of data. | |||||
Args: | |||||
batch (torch.Tensor): A tensor containing a batch of data. | |||||
Returns: | |||||
torch.Tensor: The computed loss for the batch. | |||||
""" | |||||
... | |||||
@abstractmethod | |||||
def train_step(self, batch: torch.Tensor) -> float: ... | |||||
@abstractmethod | |||||
def eval_step(self, batch: torch.Tensor) -> float: ... |
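# Hedged, illustrative-only sketch of the contract: a toy subclass (not one of the real
# trainers in this repo) showing how _loss / train_step / eval_step fit together.
if __name__ == "__main__":
    class _ToyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.bias = nn.Parameter(torch.ones(1))
    class _ToyTrainer(ModelTrainerBase):
        def _loss(self, batch: torch.Tensor) -> torch.Tensor:
            # pretend the "score" is just the bias; the loss pulls it towards zero
            return (self.model.bias**2).mean()
        def train_step(self, batch: torch.Tensor) -> float:
            self.optimizer.zero_grad()
            loss = self._loss(batch)
            loss.backward()
            self.optimizer.step()
            return float(loss.item())
        def eval_step(self, batch: torch.Tensor) -> float:
            with torch.no_grad():
                return float(self._loss(batch).item())
    toy_model = _ToyModel()
    toy_trainer = _ToyTrainer(toy_model, torch.optim.Adam(toy_model.parameters(), lr=0.1))
    dummy_batch = torch.zeros(4, 3, dtype=torch.long)
    print(toy_trainer.train_step(dummy_batch), toy_trainer.eval_step(dummy_batch))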
import torch | |||||
import torch.nn.functional as F | |||||
from einops import repeat | |||||
from pykeen.sampling import BernoulliNegativeSampler | |||||
from pykeen.training import SLCWATrainingLoop | |||||
from pykeen.triples.instances import SLCWABatch | |||||
from torch.optim import Adam | |||||
from torch.optim.lr_scheduler import StepLR | |||||
from torch.utils.data import DataLoader | |||||
from metrics.ranking import simple_ranking | |||||
from models.translation.trans_e import TransE | |||||
from tools.params import TrainingParams | |||||
from training.model_trainers.model_trainer_base import ModelTrainerBase | |||||
class TransETrainer(ModelTrainerBase): | |||||
def __init__( | |||||
self, | |||||
model: TransE, | |||||
training_params: TrainingParams, | |||||
triples_factory: torch.Tensor, | |||||
mapped_triples: torch.Tensor, | |||||
device: torch.device, | |||||
margin: float = 1.0, | |||||
n_negative: int = 1, | |||||
loss_fn: str = "margin", | |||||
**kwargs, | |||||
) -> None: | |||||
optimizer = Adam( | |||||
model.inner_model.parameters(), | |||||
lr=training_params.lr, | |||||
weight_decay=training_params.weight_decay, | |||||
) | |||||
# scheduler = None # Placeholder for scheduler, can be implemented later | |||||
scheduler = StepLR( | |||||
optimizer, | |||||
step_size=training_params.lr_step, | |||||
gamma=training_params.gamma, | |||||
) | |||||
super().__init__(model, optimizer, scheduler, device) | |||||
self.triples_factory = triples_factory | |||||
self.margin = margin | |||||
self.n_negative = n_negative | |||||
        self.loss_fn = loss_fn
        # Temperature for the self-adversarial weighting used when loss_fn == "adversarial".
        self.adv_temp = kwargs.get("adv_temp", 1.0)
self.negative_sampler = BernoulliNegativeSampler( | |||||
mapped_triples=mapped_triples.to(device), | |||||
num_entities=model.num_entities, | |||||
num_relations=model.num_relations, | |||||
num_negs_per_pos=n_negative, | |||||
).to(device) | |||||
self.training_loop = SLCWATrainingLoop( | |||||
model=model.inner_model, | |||||
triples_factory=triples_factory, | |||||
optimizer=optimizer, | |||||
negative_sampler=self.negative_sampler, | |||||
lr_scheduler=scheduler, | |||||
) | |||||
def create_data_loader( | |||||
self, | |||||
triples_factory: torch.Tensor, | |||||
batch_size: int, | |||||
shuffle: bool = True, | |||||
) -> DataLoader: | |||||
triples_factory = self.triples_factory if triples_factory is None else triples_factory | |||||
return self.training_loop._create_training_data_loader( | |||||
triples_factory=triples_factory, | |||||
sampler=None, | |||||
batch_size=batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=False, | |||||
) | |||||
def _loss(self, pos_scores: torch.Tensor, neg_scores: torch.Tensor) -> torch.Tensor: | |||||
""" | |||||
Compute the loss based on the positive and negative scores. | |||||
Args: | |||||
pos_scores (torch.Tensor): Scores for positive triples. | |||||
neg_scores (torch.Tensor): Scores for negative triples. | |||||
Returns: | |||||
torch.Tensor: The computed loss. | |||||
""" | |||||
k = neg_scores.shape[1] | |||||
if self.loss_fn == "margin": | |||||
target = -torch.ones_like(neg_scores) # want s_pos < s_neg | |||||
loss = F.margin_ranking_loss( | |||||
repeat(pos_scores, "b -> b k", k=k), | |||||
neg_scores, | |||||
target, | |||||
margin=self.margin, | |||||
) | |||||
return loss | |||||
if self.loss_fn == "adversarial": | |||||
# b = pos_scores.size(0) | |||||
# neg_scores = rearrange(neg_scores, "(b k) -> b k", b=b, k=k) | |||||
# importance weights (detach so gradients flow only through log-sigmoid terms) | |||||
w = F.softmax(self.adv_temp * neg_scores, dim=1).detach() | |||||
pos_term = -F.logsigmoid(self.margin - pos_scores) # (B,) | |||||
neg_term = -(w * F.logsigmoid(neg_scores - self.margin)).sum(dim=1) # (B,) | |||||
loss = (pos_term + neg_term).mean() | |||||
return loss | |||||
        msg = f"Unknown loss_fn: {self.loss_fn!r} (expected 'margin' or 'adversarial')"
        raise ValueError(msg)
def train_step(self, batch: SLCWABatch, step: int) -> float: | |||||
""" | |||||
Perform a training step on the given batch of data. | |||||
Args: | |||||
batch (torch.Tensor): A tensor containing a batch of triples. | |||||
Returns: | |||||
float: The computed loss for the batch. | |||||
""" | |||||
continue_training = step > 0 | |||||
loss = self.training_loop.train( | |||||
triples_factory=self.triples_factory, | |||||
num_epochs=step + 1, | |||||
batch_size=self.training_loop._get_batch_size(batch), | |||||
continue_training=continue_training, | |||||
use_tqdm=False, | |||||
use_tqdm_batch=False, | |||||
label_smoothing=0.0, # Assuming no label smoothing for simplicity | |||||
# num_workers=3, | |||||
)[-1] | |||||
return loss | |||||
def eval_step(self, batch: torch.Tensor) -> torch.Tensor: | |||||
"""Return rank tensor for this batch (needed for MRR / Hits@K).""" | |||||
self.model.eval() | |||||
with torch.no_grad(): | |||||
return simple_ranking(self.model, batch.to(self.device)) |
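# Hedged numeric sketch (scores are made up): how the two branches of _loss above behave
# on a toy batch of 2 positives with 3 negatives each. The formulas are repeated inline
# here rather than constructing a full TransETrainer.
if __name__ == "__main__":
    pos_scores = torch.tensor([0.5, 1.5])  # (B,)
    neg_scores = torch.tensor([[1.0, 2.0, 0.2], [2.5, 1.0, 3.0]])  # (B, K)
    margin, adv_temp = 1.0, 1.0
    # margin ranking: mean over max(0, margin + pos - neg)
    margin_loss = F.margin_ranking_loss(
        repeat(pos_scores, "b -> b k", k=neg_scores.shape[1]),
        neg_scores,
        -torch.ones_like(neg_scores),
        margin=margin,
    )
    # self-adversarial: softmax-weighted log-sigmoid terms
    w = F.softmax(adv_temp * neg_scores, dim=1).detach()
    adv_loss = (
        -F.logsigmoid(margin - pos_scores)
        - (w * F.logsigmoid(neg_scores - margin)).sum(dim=1)
    ).mean()
    print(f"margin={margin_loss.item():.4f}  adversarial={adv_loss.item():.4f}")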
import torch | |||||
from pykeen.losses import MarginRankingLoss | |||||
from pykeen.sampling import BernoulliNegativeSampler | |||||
from pykeen.training import SLCWATrainingLoop | |||||
from pykeen.triples.instances import SLCWABatch | |||||
from torch.optim import Adam | |||||
from metrics.ranking import simple_ranking | |||||
from models.translation.trans_h import TransH | |||||
from tools.params import TrainingParams | |||||
from tools.train import StepMinLR | |||||
from training.model_trainers.model_trainer_base import ModelTrainerBase | |||||
class TransHTrainer(ModelTrainerBase): | |||||
def __init__( | |||||
self, | |||||
model: TransH, | |||||
training_params: TrainingParams, | |||||
triples_factory: torch.Tensor, | |||||
mapped_triples: torch.Tensor, | |||||
device: torch.device, | |||||
margin: float = 1.0, | |||||
n_negative: int = 1, | |||||
regul_rate: float = 0.0, | |||||
loss_fn: str = "margin", | |||||
**kwargs, | |||||
) -> None: | |||||
optimizer = Adam( | |||||
model.inner_model.parameters(), | |||||
lr=training_params.lr, | |||||
weight_decay=training_params.weight_decay, | |||||
) | |||||
scheduler = StepMinLR( | |||||
optimizer, | |||||
step_size=training_params.lr_step, | |||||
gamma=training_params.gamma, | |||||
min_lr=training_params.min_lr, | |||||
) | |||||
super().__init__(model, optimizer, scheduler, device) | |||||
self.triples_factory = triples_factory | |||||
self.margin = margin | |||||
self.n_negative = n_negative | |||||
self.regul_rate = regul_rate | |||||
self.loss_fn = loss_fn | |||||
self.negative_sampler = BernoulliNegativeSampler( | |||||
mapped_triples=mapped_triples.to(device), | |||||
num_entities=model.num_entities, | |||||
num_relations=model.num_relations, | |||||
num_negs_per_pos=n_negative, | |||||
).to(device) | |||||
self.training_loop = SLCWATrainingLoop( | |||||
model=model.inner_model, | |||||
triples_factory=triples_factory, | |||||
optimizer=optimizer, | |||||
negative_sampler=self.negative_sampler, | |||||
lr_scheduler=scheduler, | |||||
) | |||||
def _loss(self) -> torch.Tensor: | |||||
return self.model.loss | |||||
def create_data_loader( | |||||
self, | |||||
triples_factory: torch.Tensor, | |||||
batch_size: int, | |||||
shuffle: bool = True, | |||||
) -> torch.utils.data.DataLoader: | |||||
triples_factory = self.triples_factory if triples_factory is None else triples_factory | |||||
return self.training_loop._create_training_data_loader( | |||||
triples_factory=triples_factory, | |||||
sampler=None, | |||||
batch_size=batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=False, | |||||
) | |||||
    @property
    def loss(self) -> MarginRankingLoss:
        return MarginRankingLoss(margin=self.margin, reduction="mean")
def train_step( | |||||
self, | |||||
batch: SLCWABatch, | |||||
step: int, | |||||
) -> float: | |||||
""" | |||||
Perform a training step on the given batch of data. | |||||
Args: | |||||
batch (torch.Tensor): A tensor containing a batch of triples. | |||||
Returns: | |||||
float: The computed loss for the batch. | |||||
""" | |||||
continue_training = step > 0 | |||||
loss = self.training_loop.train( | |||||
triples_factory=self.triples_factory, | |||||
num_epochs=step + 1, | |||||
batch_size=self.training_loop._get_batch_size(batch), | |||||
continue_training=continue_training, | |||||
use_tqdm=False, | |||||
use_tqdm_batch=False, | |||||
label_smoothing=0.0, # Assuming no label smoothing for simplicity | |||||
)[-1] | |||||
return loss | |||||
def eval_step(self, batch: torch.Tensor) -> torch.Tensor: | |||||
"""Return rank tensor for this batch (needed for MRR / Hits@K).""" | |||||
self.model.eval() | |||||
with torch.no_grad(): | |||||
return simple_ranking(self.model, batch.to(self.device)) |
import torch | |||||
from pykeen.sampling import BernoulliNegativeSampler | |||||
from pykeen.training import SLCWATrainingLoop | |||||
from pykeen.triples.instances import SLCWABatch | |||||
from torch.optim import Adam | |||||
from metrics.ranking import simple_ranking | |||||
from models.translation.trans_r import TransR | |||||
from tools.params import TrainingParams | |||||
from tools.train import StepMinLR | |||||
from training.model_trainers.model_trainer_base import ModelTrainerBase | |||||
class TransRTrainer(ModelTrainerBase): | |||||
def __init__( | |||||
self, | |||||
model: TransR, | |||||
training_params: TrainingParams, | |||||
triples_factory: torch.Tensor, | |||||
mapped_triples: torch.Tensor, | |||||
device: torch.device, | |||||
margin: float = 1.0, | |||||
n_negative: int = 1, | |||||
regul_rate: float = 0.0, | |||||
loss_fn: str = "margin", | |||||
**kwargs, | |||||
) -> None: | |||||
optimizer = Adam( | |||||
model.inner_model.parameters(), | |||||
lr=training_params.lr, | |||||
weight_decay=training_params.weight_decay, | |||||
) | |||||
scheduler = StepMinLR( | |||||
optimizer, | |||||
step_size=training_params.lr_step, | |||||
gamma=training_params.gamma, | |||||
min_lr=training_params.min_lr, | |||||
) | |||||
super().__init__(model, optimizer, scheduler, device) | |||||
self.triples_factory = triples_factory | |||||
self.margin = margin | |||||
self.n_negative = n_negative | |||||
self.regul_rate = regul_rate | |||||
self.loss_fn = loss_fn | |||||
self.negative_sampler = BernoulliNegativeSampler( | |||||
mapped_triples=mapped_triples.to(device), | |||||
num_entities=model.num_entities, | |||||
num_relations=model.num_relations, | |||||
num_negs_per_pos=n_negative, | |||||
).to(device) | |||||
self.training_loop = SLCWATrainingLoop( | |||||
model=model.inner_model, | |||||
triples_factory=triples_factory, | |||||
optimizer=optimizer, | |||||
negative_sampler=self.negative_sampler, | |||||
lr_scheduler=scheduler, | |||||
) | |||||
def create_data_loader( | |||||
self, | |||||
triples_factory: torch.Tensor, | |||||
batch_size: int, | |||||
shuffle: bool = True, | |||||
) -> torch.utils.data.DataLoader: | |||||
triples_factory = self.triples_factory if triples_factory is None else triples_factory | |||||
return self.training_loop._create_training_data_loader( | |||||
triples_factory=triples_factory, | |||||
sampler=None, | |||||
batch_size=batch_size, | |||||
shuffle=shuffle, | |||||
drop_last=False, | |||||
) | |||||
def _loss(self) -> torch.Tensor: | |||||
return self.model.loss | |||||
def train_step( | |||||
self, | |||||
batch: SLCWABatch, | |||||
step: int, | |||||
) -> float: | |||||
""" | |||||
Perform a training step on the given batch of data. | |||||
Args: | |||||
batch (torch.Tensor): A tensor containing a batch of triples. | |||||
Returns: | |||||
float: The computed loss for the batch. | |||||
""" | |||||
continue_training = step > 0 | |||||
loss = self.training_loop.train( | |||||
triples_factory=self.triples_factory, | |||||
num_epochs=step + 1, | |||||
batch_size=self.training_loop._get_batch_size(batch), | |||||
continue_training=continue_training, | |||||
use_tqdm=False, | |||||
use_tqdm_batch=False, | |||||
label_smoothing=0.0, # Assuming no label smoothing for simplicity | |||||
)[-1] | |||||
return loss | |||||
def eval_step(self, batch: torch.Tensor) -> torch.Tensor: | |||||
"""Return rank tensor for this batch (needed for MRR / Hits@K).""" | |||||
self.model.eval() | |||||
with torch.no_grad(): | |||||
return simple_ranking(self.model, batch.to(self.device)) |
from __future__ import annotations | |||||
from pathlib import Path | |||||
from typing import TYPE_CHECKING, Any | |||||
import torch | |||||
from pykeen.evaluation import RankBasedEvaluator | |||||
from pykeen.metrics.ranking import HitsAtK, InverseHarmonicMeanRank | |||||
from torch.utils.data import DataLoader | |||||
from torch.utils.tensorboard import SummaryWriter | |||||
from tools import ( | |||||
CheckpointManager, | |||||
DeterministicRandomSampler, | |||||
TensorBoardHandler, | |||||
get_pretty_logger, | |||||
set_seed, | |||||
) | |||||
if TYPE_CHECKING:
    from collections.abc import Iterator
    from data.kg_dataset import KGDataset
    from tools import CommonParams, TrainingParams
    from .model_trainers.model_trainer_base import ModelTrainerBase
logger = get_pretty_logger(__name__) | |||||
cpu_device = torch.device("cpu") | |||||
class Trainer: | |||||
def __init__( | |||||
self, | |||||
train_dataset: KGDataset, | |||||
val_dataset: KGDataset | None, | |||||
model_trainer: ModelTrainerBase, | |||||
common_params: CommonParams, | |||||
training_params: TrainingParams, | |||||
device: torch.device = cpu_device, | |||||
) -> None: | |||||
set_seed(training_params.seed) | |||||
self.train_dataset = train_dataset | |||||
self.val_dataset = val_dataset if val_dataset is not None else train_dataset | |||||
self.model_trainer = model_trainer | |||||
self.common_params = common_params | |||||
self.training_params = training_params | |||||
self.device = device | |||||
self.train_loader = self.model_trainer.create_data_loader( | |||||
triples_factory=self.train_dataset.triples_factory, | |||||
batch_size=training_params.batch_size, | |||||
shuffle=True, | |||||
) | |||||
self.train_iterator = self._data_iterator(self.train_loader) | |||||
seed = 1234 | |||||
val_sampler = DeterministicRandomSampler( | |||||
len(self.val_dataset), | |||||
sample_size=self.training_params.validation_sample_size, | |||||
seed=seed, | |||||
) | |||||
self.valid_loader = DataLoader( | |||||
self.val_dataset, | |||||
batch_size=training_params.validation_batch_size, | |||||
shuffle=False, | |||||
num_workers=training_params.num_workers, | |||||
sampler=val_sampler, | |||||
) | |||||
self.valid_iterator = self._data_iterator(self.valid_loader) | |||||
self.step = 0 | |||||
self.log_every = common_params.log_every | |||||
self.log_console_every = common_params.log_console_every | |||||
self.save_dpath = common_params.save_dpath | |||||
self.save_every = common_params.save_every | |||||
self.checkpoint_manager = CheckpointManager( | |||||
root_directory=self.save_dpath, | |||||
run_name=common_params.run_name, | |||||
load_only=common_params.evaluate_only, | |||||
) | |||||
self.ckpt_dpath = self.checkpoint_manager.checkpoint_directory | |||||
if not common_params.evaluate_only: | |||||
logger.info(f"Saving checkpoints to {self.ckpt_dpath}") | |||||
if self.common_params.load_path: | |||||
self.load() | |||||
self.writer = None | |||||
if not common_params.evaluate_only: | |||||
run_name = self.checkpoint_manager.run_name | |||||
self.writer = SummaryWriter(log_dir=f"{common_params.log_dir}/{run_name}") | |||||
logger.addHandler(TensorBoardHandler(self.writer)) | |||||
    def log(self, tag: str, log_value: Any, console_msg: str | None = None) -> None:
        """
        Log a scalar value to TensorBoard and, optionally, a message to the console.
        Args:
            tag (str): TensorBoard tag for the scalar.
            log_value (Any): Value recorded at the current training step.
            console_msg (str | None): If given, also written to the console logger.
        """
if console_msg is not None: | |||||
logger.info(f"{console_msg}") | |||||
if self.writer is not None: | |||||
self.writer.add_scalar(tag, log_value, self.step) | |||||
def save(self) -> None: | |||||
save_dpath = self.ckpt_dpath | |||||
save_dpath = Path(save_dpath) | |||||
save_dpath.mkdir(parents=True, exist_ok=True) | |||||
self.model_trainer.save(save_dpath, self.step) | |||||
def load(self) -> None: | |||||
load_path = self.checkpoint_manager.get_model_fpath(self.common_params.load_path) | |||||
load_fpath = Path(load_path) | |||||
if load_fpath.exists(): | |||||
self.step = self.model_trainer.load(load_fpath) | |||||
logger.info(f"Model loaded from {load_fpath}.") | |||||
else: | |||||
msg = f"Load path {load_fpath} does not exist." | |||||
raise FileNotFoundError(msg) | |||||
    def _data_iterator(self, loader: DataLoader) -> Iterator:
"""Yield batches of positive and negative triples.""" | |||||
while True: | |||||
yield from loader | |||||
    def train(self) -> None:
        """Run a single optimisation step on the next training batch and log its loss."""
data_batch = next(self.train_iterator) | |||||
loss = self.model_trainer.train_step(data_batch, self.step) | |||||
if self.step % self.log_console_every == 0: | |||||
logger.info(f"[train] step {self.step:6d} | loss = {loss:.4f}") | |||||
if self.step % self.log_every == 0: | |||||
self.log("train/loss", loss) | |||||
self.step += 1 | |||||
    def evaluate(self) -> dict[str, float]:
ks = [1, 3, 10] | |||||
metrics = [InverseHarmonicMeanRank()] + [HitsAtK(k) for k in ks] | |||||
evaluator = RankBasedEvaluator( | |||||
metrics=metrics, | |||||
filtered=True, | |||||
batch_size=self.training_params.validation_batch_size, | |||||
) | |||||
result = evaluator.evaluate( | |||||
model=self.model_trainer.model, | |||||
mapped_triples=self.val_dataset.triples, | |||||
additional_filter_triples=[ | |||||
self.train_dataset.triples, | |||||
self.val_dataset.triples, | |||||
], | |||||
batch_size=self.training_params.validation_batch_size, | |||||
) | |||||
mrr = result.get_metric("both.realistic.inverse_harmonic_mean_rank") | |||||
hits_at = {f"hits_at_{k}": result.get_metric(f"both.realistic.hits_at_{k}") for k in ks} | |||||
        metrics = {
            "mrr": mrr,
            **hits_at,
        }
        for k, v in metrics.items():
            self.log(f"valid/{k}", v, console_msg=f"[valid] {k:6s} = {v:.4f}")
        return metrics
# ───────────────────────── # | |||||
# master schedule | |||||
# ───────────────────────── # | |||||
def run(self) -> None: | |||||
""" | |||||
Alternate training and evaluation until `total_steps` | |||||
optimisation steps have been executed. | |||||
""" | |||||
total_steps = self.training_params.num_train_steps | |||||
eval_every = self.training_params.eval_every | |||||
while self.step < total_steps: | |||||
self.train() | |||||
if self.step % eval_every == 0: | |||||
logger.info(f"Evaluating at step {self.step}...") | |||||
self.evaluate() | |||||
if self.step % self.save_every == 0: | |||||
logger.info(f"Saving model at step {self.step}...") | |||||
self.save() | |||||
logger.info("Training complete.") | |||||
def reset(self) -> None: | |||||
""" | |||||
Reset the trainer state, including the model trainer and data iterators. | |||||
""" | |||||
self.model_trainer.reset() | |||||
self.train_iterator = self._data_iterator(self.train_loader) | |||||
self.valid_iterator = self._data_iterator(self.valid_loader) | |||||
self.step = 0 | |||||
logger.info("Trainer state has been reset.") |