@@ -0,0 +1 @@ | |||
# KGEvaluation |
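A minimal usage sketch (it assumes the `data` and `metrics` packages from this repo are importable, and mirrors the dataset-analysis script further down in this diff):

```python
from data import YAGO310Dataset
from metrics.wlcrec import WLCREC

# Load the training split and compute the WL-CREC summary statistics.
dataset = YAGO310Dataset(split="train")
entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = WLCREC(dataset).compute(H=5)
print(f"C_NWLEC = {c_nwlec:.6f}, D_NWLEC = {d_nwlec:.6f}")
```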
@@ -0,0 +1,106 @@ | |||
from __future__ import annotations | |||
import argparse | |||
import torch | |||
from data import YAGO310Dataset | |||
# from metrics.crec_modifier import tune_crec_parallel_edits | |||
from data.kg_dataset import KGDataset | |||
from metrics.crec_radius_sample import find_sample | |||
from metrics.wlcrec import WLCREC | |||
from tools import get_pretty_logger | |||
logger = get_pretty_logger(__name__)
# -------------------------------------------------------------------- | |||
# CLI | |||
# -------------------------------------------------------------------- | |||
def main(): | |||
p = argparse.ArgumentParser() | |||
p.add_argument("--depth", type=int, default=5) | |||
p.add_argument("--lower", type=float, default=0.25) | |||
p.add_argument("--upper", type=float, default=0.26) | |||
p.add_argument("--seed", type=int, default=0) | |||
p.add_argument("--save", type=str, default="fb15k_crec_bi.pt") | |||
args = p.parse_args() | |||
# ds = FB15KDataset() | |||
ds = YAGO310Dataset() | |||
triples = ds.triples | |||
n_ent, n_rel = ds.num_entities, ds.num_relations | |||
logging.info("YAGO3-10 loaded |V|=%d |R|=%d |T|=%d", n_ent, n_rel, len(triples)) | |||
# wlcrec = WLCREC(ds) | |||
# c = wlcrec.compute(5)[-2] # C_NWLEC | |||
# colours = wlcrec.wl_colours(args.depth)[-1] | |||
# tuned = tune_crec( | |||
# wlcrec, n_ent, n_rel, depth=args.depth, lo=args.lower, hi=args.upper, seed=args.seed | |||
# ) | |||
# tuned = tune_crec_parallel_edits( | |||
# triples, | |||
# n_ent, | |||
# n_rel, | |||
# depth=args.depth, | |||
# lo=args.lower, | |||
# hi=args.upper, | |||
# seed=args.seed, | |||
# ) | |||
# tuned = fast_tune_crec( | |||
# triples.tolist(), | |||
# colours, | |||
# n_ent, | |||
# n_rel, | |||
# depth=args.depth, | |||
# c=c, | |||
# lo=args.lower, | |||
# hi=args.upper, | |||
# max_iters=1000, | |||
# max_workers=10, # Adjust number of workers as needed | |||
# ) | |||
tuned = find_sample( | |||
triples.tolist(), | |||
n_ent, | |||
n_rel, | |||
depth=args.depth, | |||
ratio=0.1, # Adjust ratio as needed | |||
radius=2, # Adjust radius as needed | |||
lower=args.lower, | |||
upper=args.upper, | |||
max_tries=1000, | |||
seed=58, | |||
) | |||
tuned_ds = KGDataset( | |||
triples_factory=None, | |||
triples=torch.tensor(tuned, dtype=torch.long), | |||
num_entities=n_ent, | |||
num_relations=n_rel, | |||
split="train", | |||
) | |||
tuned_wlcrec = WLCREC(tuned_ds) | |||
entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = tuned_wlcrec.compute(5) | |||
print( | |||
f"\nTuned CREC results (H={5})", | |||
f"\n • avg WLEC : {entropy:.6f} nats", | |||
f"\n • C_ratio : {c_ratio:.6f}", | |||
f"\n • C_NWLEC : {c_nwlec:.6f}", | |||
f"\n • H_cond(R|S_H) : {h_cond:.6f} nats", | |||
f"\n • D_ratio : {d_ratio:.6f}", | |||
f"\n • D_NWLEC : {d_nwlec:.6f}", | |||
) | |||
torch.save(torch.tensor(tuned, dtype=torch.long), args.save) | |||
logging.info("Saved tuned triples → %s", args.save) | |||
if __name__ == "__main__": | |||
main() |
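# Reload sketch (illustrative only; the file name is whatever was passed via --save, and
# `tf` is an existing PyKEEN TriplesFactory, as used by the dataset wrappers below):
#
#   custom_triples = torch.load(args.save, map_location="cpu").to(torch.long)
#   tf = tf.clone_and_exchange_triples(custom_triples)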
@@ -0,0 +1,20 @@ | |||
_target_: tools.params.CommonParams | |||
model_name: "default_model" | |||
project_name: "my_project" | |||
run_name: ... | |||
# Where to write all outputs (checkpoints / tensorboard logs / etc.) | |||
save_dpath: "./checkpoints" | |||
save_every: 200 | |||
# If you want to resume from a checkpoint, set load_path to a string path;
# otherwise leave it null or empty. | |||
load_path: null | |||
# Default log directory | |||
log_dir: "logs" | |||
log_every: 10 | |||
log_console_every: 10 | |||
evaluate_only: false |
@@ -0,0 +1,9 @@ | |||
defaults: | |||
- common: common | |||
- model: model | |||
- data: data | |||
- training: training | |||
hydra: | |||
run: | |||
dir: "./outputs" |
@@ -0,0 +1,3 @@ | |||
_target_: data.wn18.WN18RRDataset
split: train |
@@ -0,0 +1,8 @@ | |||
train: | |||
_target_: data.fb15k.FB15KDataset | |||
split: train | |||
valid: | |||
_target_: data.fb15k.FB15KDataset | |||
split: valid |
@@ -0,0 +1,8 @@ | |||
train: | |||
_target_: data.wn18.WN18Dataset | |||
split: train | |||
valid: | |||
_target_: data.wn18.WN18Dataset | |||
split: valid |
@@ -0,0 +1,8 @@ | |||
train: | |||
_target_: data.wn18.WN18RRDataset | |||
split: train | |||
valid: | |||
_target_: data.wn18.WN18RRDataset | |||
split: valid |
@@ -0,0 +1,9 @@ | |||
train: | |||
_target_: data.yago3_10.YAGO310Dataset | |||
split: train | |||
valid: | |||
_target_: data.yago3_10.YAGO310Dataset | |||
split: valid | |||
@@ -0,0 +1,3 @@ | |||
num_entities: 100 | |||
num_relations: 100 | |||
dim: 100 |
@@ -0,0 +1,9 @@ | |||
defaults: | |||
- model | |||
- _self_ | |||
_target_: models.translation.trans_e.TransE | |||
dim: 200 | |||
sf_norm: 1 | |||
p_norm: true |
@@ -0,0 +1,11 @@ | |||
defaults: | |||
- model | |||
- _self_ | |||
_target_: models.translation.trans_h.TransH | |||
dim: 200 | |||
sf_norm: 1 | |||
p_norm: true | |||
p_norm_value: 2 | |||
margin: 1.0 |
@@ -0,0 +1,12 @@ | |||
defaults: | |||
- model | |||
- _self_ | |||
_target_: models.translation.trans_r.TransR | |||
entity_dim: 200 | |||
relation_dim: 30 | |||
sf_norm: 1 | |||
p_norm: true | |||
p_norm_value: 2 | |||
margin: 1.0 |
@@ -0,0 +1,21 @@ | |||
_target_: tools.params.TrainingParams | |||
lr: 0.005 | |||
min_lr: 0.0001 | |||
weight_decay: 0.0 | |||
t0: 50 | |||
lr_step: 10 | |||
gamma: 0.95 | |||
batch_size: 8192 | |||
num_workers: 3 | |||
seed: 42 | |||
num_train_steps: 10000 | |||
eval_every: 100 | |||
validation_sample_size: 1000 | |||
validation_batch_size: 64 | |||
model_trainer: | |||
_target_: training.model_trainers.model_trainer_base.ModelTrainerBase |
@@ -0,0 +1,11 @@ | |||
defaults: | |||
- training | |||
- _self_ | |||
model_trainer: | |||
_target_: training.model_trainers.translation.trans_e_trainer.TransETrainer | |||
margin: 1.0 | |||
n_negative: 5 | |||
regul_rate: 0.0 | |||
loss_fn: "margin" |
@@ -0,0 +1,11 @@ | |||
defaults: | |||
- training | |||
- _self_ | |||
model_trainer: | |||
_target_: training.model_trainers.translation.trans_r_trainer.TransRTrainer | |||
margin: 1.0 | |||
n_negative: 10 | |||
regul_rate: 0.0 | |||
loss_fn: "margin" |
@@ -0,0 +1,17 @@ | |||
defaults: | |||
- config | |||
- override common: common | |||
- override data: fb15k | |||
- override model: trans_e | |||
- override training: trans_e_trainer | |||
- _self_ | |||
training: | |||
model_trainer: | |||
loss_fn: "margin" | |||
common: | |||
run_name: "TransE_FB15K" | |||
evaluate_only: true | |||
load_path: (2, 1800) |
@@ -0,0 +1,17 @@ | |||
defaults: | |||
- config | |||
- override common: common | |||
- override data: wn18 | |||
- override model: trans_e | |||
- override training: trans_e_trainer | |||
- _self_ | |||
training: | |||
model_trainer: | |||
loss_fn: "margin" | |||
common: | |||
run_name: "TransE_WN18" | |||
evaluate_only: true | |||
load_path: (82, 3000) |
@@ -0,0 +1,18 @@ | |||
defaults: | |||
- config | |||
- override common: common | |||
- override data: wn18rr | |||
- override model: trans_e | |||
- override training: trans_e_trainer | |||
- _self_ | |||
training: | |||
model_trainer: | |||
loss_fn: "margin" | |||
common: | |||
run_name: "TransE_WN18rr" | |||
evaluate_only: true | |||
load_path: (2, 6000) |
@@ -0,0 +1,17 @@ | |||
defaults: | |||
- config | |||
- override common: common | |||
- override data: yago3_10 | |||
- override model: trans_e | |||
- override training: trans_e_trainer | |||
- _self_ | |||
training: | |||
model_trainer: | |||
loss_fn: "margin" | |||
common: | |||
run_name: "TransE_YAGO310" | |||
evaluate_only: true | |||
load_path: (1, 2600) |
@@ -0,0 +1,17 @@ | |||
defaults: | |||
- config | |||
- override common: common | |||
- override data: wn18 | |||
- override model: trans_r | |||
- override training: trans_r_trainer | |||
- _self_ | |||
training: | |||
model_trainer: | |||
loss_fn: "margin" | |||
common: | |||
run_name: "TransR_WN18" | |||
# evaluate_only: true | |||
# load_path: (1, 1000) |
@@ -0,0 +1,6 @@ | |||
from .fb15k import FB15KDataset | |||
from .wn18 import WN18Dataset, WN18RRDataset | |||
from .yago3_10 import YAGO310Dataset | |||
from .hationet import HetionetDataset | |||
from .open_bio_link import OpenBioLinkDataset | |||
from .openke_wiki import OpenKEWikiDataset |
@@ -0,0 +1,49 @@ | |||
from typing import List, Tuple | |||
import torch | |||
from pykeen.datasets import FB15k | |||
from .kg_dataset import KGDataset | |||
class FB15KDataset(KGDataset): | |||
"""A wrapper around PyKeen's FB15k dataset producing KGDataset instances.""" | |||
def __init__(self, split: str = "train") -> None: | |||
dataset = FB15k() | |||
# if split == "train": | |||
if True: | |||
tf = dataset.training | |||
elif split in ("valid", "val", "validation"): | |||
tf = dataset.validation | |||
elif split in ("test", "testing"): | |||
tf = dataset.testing | |||
else: | |||
msg = f"Unrecognized split '{split}'" | |||
raise ValueError(msg) | |||
custom_triples = torch.load( | |||
"/home/n_kazemi/projects/KGEvaluation/CustomDataset/fb15k/fb15k_crec_bi_1.pt", | |||
map_location=torch.device("cpu"), | |||
).to(torch.long) | |||
# .tolist() | |||
tf = tf.clone_and_exchange_triples( | |||
custom_triples, | |||
) | |||
triples_list: List[Tuple[int, int, int]] = [ | |||
(int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||
] | |||
super().__init__( | |||
triples_factory=tf, | |||
triples=triples_list, | |||
num_entities=dataset.num_entities, | |||
num_relations=dataset.num_relations, | |||
split=split, | |||
) | |||
def __getitem__(self, idx: int) -> torch.Tensor: | |||
return torch.tensor(super().__getitem__(idx), dtype=torch.long) |
@@ -0,0 +1,57 @@ | |||
from typing import List, Tuple | |||
import torch | |||
from pykeen.datasets import Hetionet | |||
from .kg_dataset import KGDataset # Importing KGDataset from the same module | |||
class HetionetDataset(KGDataset): | |||
"""A wrapper around PyKeen's Hetionet dataset that produces KGDataset instances.""" | |||
def __init__(self, split: str = "train") -> None: | |||
""" | |||
        Initialize the Hetionet wrapper.
        :param split: One of "train", "valid", or "test" indicating which split to load.
        """
        # Load the PyKeen Hetionet dataset
dataset = Hetionet() | |||
# Select the appropriate TriplesFactory based on the requested split | |||
if split == "train": | |||
tf = dataset.training | |||
elif split in ("valid", "val", "validation"): | |||
tf = dataset.validation | |||
elif split in ("test", "testing"): | |||
tf = dataset.testing | |||
else: | |||
msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||
raise ValueError( | |||
msg, | |||
) | |||
        # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3)
        # Convert each row into a Python tuple of ints
triples_list: List[Tuple[int, int, int]] = [ | |||
(int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||
] | |||
# Get the total number of entities and relations in the entire dataset | |||
num_entities = dataset.num_entities | |||
num_relations = dataset.num_relations | |||
# Initialize the base KGDataset with the selected triples | |||
super().__init__( | |||
triples_factory=tf, | |||
triples=triples_list, | |||
num_entities=num_entities, | |||
num_relations=num_relations, | |||
split=split, | |||
) | |||
def __getitem__(self, idx: int) -> torch.Tensor: | |||
data = super().__getitem__(idx) | |||
# Convert the tuple to a torch tensor | |||
return torch.tensor(data, dtype=torch.long) |
@@ -0,0 +1,89 @@ | |||
"""PyTorch Dataset wrapping indexed triples (h, r, t).""" | |||
from pathlib import Path | |||
from typing import Dict, List, Optional, Tuple
import torch | |||
from pykeen.triples import TriplesFactory
from torch.utils.data import Dataset | |||
class KGDataset(Dataset): | |||
"""PyTorch Dataset wrapping indexed triples (h, r, t).""" | |||
def __init__( | |||
self, | |||
        triples_factory: Optional[TriplesFactory],
triples: List[Tuple[int, int, int]], | |||
num_entities: int, | |||
num_relations: int, | |||
split: str = "train", | |||
) -> None: | |||
super().__init__() | |||
self.triples_factory = triples_factory | |||
self.triples = torch.tensor(triples, dtype=torch.long) | |||
self.num_entities = num_entities | |||
self.num_relations = num_relations | |||
self.split = split | |||
def __len__(self) -> int: | |||
"""Return the number of triples in the dataset.""" | |||
return len(self.triples) | |||
    def __getitem__(self, idx: int) -> torch.Tensor:
"""Return the indexed triple at the given index.""" | |||
return self.triples[idx] | |||
def entities(self) -> torch.Tensor: | |||
"""Return a list of all entities in the dataset.""" | |||
return torch.arange(self.num_entities) | |||
def relations(self) -> torch.Tensor: | |||
"""Return a list of all relations in the dataset.""" | |||
return torch.arange(self.num_relations) | |||
def __repr__(self) -> str: | |||
return ( | |||
f"{self.__class__.__name__}(split={self.split!r}, " | |||
f"num_triples={len(self.triples)}, " | |||
f"num_entities={self.num_entities}, " | |||
f"num_relations={self.num_relations})" | |||
) | |||
# ---------------------------- Helper functions -------------------------------- | |||
def _map_to_ids( | |||
triples: List[Tuple[str, str, str]], | |||
) -> Tuple[List[Tuple[int, int, int]], Dict[str, int], Dict[str, int]]: | |||
ent2id: Dict[str, int] = {} | |||
rel2id: Dict[str, int] = {} | |||
mapped: List[Tuple[int, int, int]] = [] | |||
def _id(d: Dict[str, int], k: str) -> int: | |||
d.setdefault(k, len(d)) | |||
return d[k] | |||
for h, r, t in triples: | |||
mapped.append((_id(ent2id, h), _id(rel2id, r), _id(ent2id, t))) | |||
return mapped, ent2id, rel2id | |||
def load_kg_from_files( | |||
    root: str, splits: tuple[str, ...] = ("train", "valid", "test")
) -> Tuple[Dict[str, KGDataset], Dict[str, int], Dict[str, int]]: | |||
root = Path(root) | |||
split_raw, all_raw = {}, [] | |||
for s in splits: | |||
lines = [tuple(line.split("\t")) for line in (root / f"{s}.txt").read_text().splitlines()] | |||
split_raw[s] = lines | |||
all_raw.extend(lines) | |||
mapped, ent2id, rel2id = _map_to_ids(all_raw) | |||
datasets, start = {}, 0 | |||
for s in splits: | |||
end = start + len(split_raw[s]) | |||
        datasets[s] = KGDataset(
            triples_factory=None,
            triples=mapped[start:end],
            num_entities=len(ent2id),
            num_relations=len(rel2id),
            split=s,
        )
start = end | |||
return datasets, ent2id, rel2id |
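# Usage sketch for ``load_kg_from_files`` (illustrative; "./data/my_kg" is a placeholder
# directory holding tab-separated train.txt / valid.txt / test.txt files of raw
# (head, relation, tail) labels):
#
#   datasets, ent2id, rel2id = load_kg_from_files("./data/my_kg")
#   train = datasets["train"]
#   print(len(train), train.num_entities, train.num_relations)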
@@ -0,0 +1,57 @@ | |||
from typing import List, Tuple | |||
import torch | |||
from pykeen.datasets import OpenBioLink | |||
from .kg_dataset import KGDataset # Importing KGDataset from the same module | |||
class OpenBioLinkDataset(KGDataset): | |||
"""A wrapper around PyKeen's OpenBioLink dataset that produces KGDataset instances.""" | |||
def __init__(self, split: str = "train") -> None: | |||
""" | |||
        Initialize the OpenBioLink wrapper.
        :param split: One of "train", "valid", or "test" indicating which split to load.
        """
        # Load the PyKeen OpenBioLink dataset
dataset = OpenBioLink() | |||
# Select the appropriate TriplesFactory based on the requested split | |||
if split == "train": | |||
tf = dataset.training | |||
elif split in ("valid", "val", "validation"): | |||
tf = dataset.validation | |||
elif split in ("test", "testing"): | |||
tf = dataset.testing | |||
else: | |||
msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||
raise ValueError( | |||
msg, | |||
) | |||
        # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3)
        # Convert each row into a Python tuple of ints
triples_list: List[Tuple[int, int, int]] = [ | |||
(int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||
] | |||
# Get the total number of entities and relations in the entire dataset | |||
num_entities = dataset.num_entities | |||
num_relations = dataset.num_relations | |||
# Initialize the base KGDataset with the selected triples | |||
super().__init__( | |||
triples_factory=tf, | |||
triples=triples_list, | |||
num_entities=num_entities, | |||
num_relations=num_relations, | |||
split=split, | |||
) | |||
def __getitem__(self, idx: int) -> torch.Tensor: | |||
data = super().__getitem__(idx) | |||
# Convert the tuple to a torch tensor | |||
return torch.tensor(data, dtype=torch.long) |
@@ -0,0 +1,57 @@ | |||
from typing import List, Tuple | |||
import torch | |||
from pykeen.datasets import Wikidata5M | |||
from .kg_dataset import KGDataset # Importing KGDataset from the same module | |||
class OpenKEWikiDataset(KGDataset): | |||
"""A wrapper around PyKeen's OpenKE-Wiki (WikiN3) dataset that produces KGDataset instances.""" | |||
def __init__(self, split: str = "train") -> None: | |||
""" | |||
        Initialize the Wikidata5M wrapper.
        :param split: One of "train", "valid", or "test" indicating which split to load.
        """
        # Load the PyKeen Wikidata5M dataset
dataset = Wikidata5M() | |||
# Select the appropriate TriplesFactory based on the requested split | |||
if split == "train": | |||
tf = dataset.training | |||
elif split in ("valid", "val", "validation"): | |||
tf = dataset.validation | |||
elif split in ("test", "testing"): | |||
tf = dataset.testing | |||
else: | |||
msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||
raise ValueError( | |||
msg, | |||
) | |||
        # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3)
        # Convert each row into a Python tuple of ints
triples_list: List[Tuple[int, int, int]] = [ | |||
(int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||
] | |||
# Get the total number of entities and relations in the entire dataset | |||
num_entities = dataset.num_entities | |||
num_relations = dataset.num_relations | |||
# Initialize the base KGDataset with the selected triples | |||
super().__init__( | |||
triples_factory=tf, | |||
triples=triples_list, | |||
num_entities=num_entities, | |||
num_relations=num_relations, | |||
split=split, | |||
) | |||
def __getitem__(self, idx: int) -> torch.Tensor: | |||
data = super().__getitem__(idx) | |||
# Convert the tuple to a torch tensor | |||
return torch.tensor(data, dtype=torch.long) |
@@ -0,0 +1,112 @@ | |||
from typing import List, Tuple | |||
import torch | |||
from pykeen.datasets import WN18, WN18RR | |||
from .kg_dataset import KGDataset # Importing KGDataset from the same module | |||
class WN18Dataset(KGDataset): | |||
"""A wrapper around PyKeen's WN18 dataset that produces KGDataset instances.""" | |||
def __init__(self, split: str = "train") -> None: | |||
""" | |||
        Initialize the WN18 wrapper.
:param split: One of "train", "valid", or "test" indicating which split to load. | |||
""" | |||
# Load the PyKeen WN18 dataset | |||
dataset = WN18() | |||
# Select the appropriate TriplesFactory based on the requested split | |||
if split == "train": | |||
tf = dataset.training | |||
elif split in ("valid", "val", "validation"): | |||
tf = dataset.validation | |||
elif split in ("test", "testing"): | |||
tf = dataset.testing | |||
else: | |||
msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||
raise ValueError( | |||
msg, | |||
) | |||
        # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3)
        # Convert each row into a Python tuple of ints
triples_list: List[Tuple[int, int, int]] = [ | |||
(int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||
] | |||
# Get the total number of entities and relations in the entire dataset | |||
num_entities = dataset.num_entities | |||
num_relations = dataset.num_relations | |||
# Initialize the base KGDataset with the selected triples | |||
super().__init__( | |||
triples_factory=tf, | |||
triples=triples_list, | |||
num_entities=num_entities, | |||
num_relations=num_relations, | |||
split=split, | |||
) | |||
def __getitem__(self, idx: int) -> torch.Tensor: | |||
data = super().__getitem__(idx) | |||
# Convert the tuple to a torch tensor | |||
return torch.tensor(data, dtype=torch.long) | |||
# Assuming KGDataset is defined in the same module or already imported: | |||
# from your_module import KGDataset | |||
class WN18RRDataset(KGDataset): | |||
"""A wrapper around PyKeen's WN18RR dataset that produces KGDataset instances.""" | |||
def __init__(self, split: str = "train") -> None: | |||
""" | |||
Initialize the WN18RR wrapper. | |||
:param split: One of "train", "valid", or "test" indicating which split to load. | |||
""" | |||
# Load the PyKeen WN18RR dataset | |||
dataset = WN18RR() | |||
# Select the appropriate TriplesFactory based on the requested split | |||
if split == "train": | |||
tf = dataset.training | |||
elif split in ("valid", "val", "validation"): | |||
tf = dataset.validation | |||
elif split in ("test", "testing"): | |||
tf = dataset.testing | |||
else: | |||
msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'." | |||
raise ValueError( | |||
msg, | |||
) | |||
        # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3)
        # Convert each row into a Python tuple of ints
triples_list: List[Tuple[int, int, int]] = [ | |||
(int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||
] | |||
# Get the total number of entities and relations in the entire dataset | |||
num_entities = dataset.num_entities | |||
num_relations = dataset.num_relations | |||
# Initialize the base KGDataset with the selected triples | |||
super().__init__( | |||
triples_factory=tf, | |||
triples=triples_list, | |||
num_entities=num_entities, | |||
num_relations=num_relations, | |||
split=split, | |||
) | |||
def __getitem__(self, idx: int) -> torch.Tensor: | |||
data = super().__getitem__(idx) | |||
# Convert the tuple to a torch tensor | |||
return torch.tensor(data, dtype=torch.long) |
@@ -0,0 +1,70 @@ | |||
from typing import List, Tuple | |||
import torch | |||
from pykeen.datasets import PathDataset | |||
from .kg_dataset import KGDataset | |||
class YAGO310Dataset(KGDataset): | |||
"""Wrapper around PyKeen's YAGO3-10 dataset.""" | |||
def __init__(self, split: str = "train") -> None: | |||
# dataset = YAGO310() | |||
dataset = PathDataset( | |||
training_path="/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10/train.txt", | |||
testing_path="/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10/test.txt", | |||
validation_path="/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10/valid.txt", | |||
) | |||
if split == "train": | |||
# if True: | |||
tf = dataset.training | |||
elif split in ("valid", "val", "validation"): | |||
tf = dataset.validation | |||
elif split in ("test", "testing"): | |||
tf = dataset.testing | |||
else: | |||
msg = f"Unrecognized split '{split}'" | |||
raise ValueError(msg) | |||
custom_triples = torch.load( | |||
"/home/n_kazemi/projects/KGEvaluation/CustomDataset/yago310/yago310_crec_bi_1.pt", | |||
map_location=torch.device("cpu"), | |||
).to(torch.long) | |||
# .tolist() | |||
# get a random sample of size 5000 | |||
# seed = torch.seed() # For reproducibility | |||
# torch.manual_seed(seed) | |||
# torch.cuda.manual_seed(seed) | |||
# if custom_triples.shape[0] > 5000: | |||
# custom_triples = custom_triples[torch.randperm(custom_triples.shape[0])[:5000]] | |||
tf = tf.clone_and_exchange_triples( | |||
custom_triples, | |||
) | |||
# triples = tf.mapped_triples | |||
# # get a random sample of size 5000 | |||
# if triples.shape[0] > 5000: | |||
# indices = torch.randperm(triples.shape[0])[:5000] | |||
# triples = triples[indices] | |||
# tf = tf.clone_and_exchange_triples(triples) | |||
triples_list: List[Tuple[int, int, int]] = [ | |||
(int(h), int(r), int(t)) for h, r, t in tf.mapped_triples | |||
] | |||
super().__init__( | |||
triples_factory=tf, | |||
triples=triples_list, | |||
num_entities=dataset.num_entities, | |||
num_relations=dataset.num_relations, | |||
split=split, | |||
) | |||
def __getitem__(self, idx: int) -> torch.Tensor: | |||
return torch.tensor(super().__getitem__(idx), dtype=torch.long) |
@@ -0,0 +1,38 @@ | |||
from data import ( | |||
FB15KDataset, | |||
WN18Dataset, | |||
WN18RRDataset, | |||
YAGO310Dataset, | |||
) | |||
from metrics.wlcrec import WLCREC | |||
def main() -> None: | |||
datasets = { | |||
"YAGO3-10": YAGO310Dataset(split="train"), | |||
# "WN18": WN18Dataset(split="train"), | |||
# "WN18RR": WN18RRDataset(split="train"), | |||
# "FB15K": FB15KDataset(split="train"), | |||
# "Hetionet": HetionetDataset(split="train"), | |||
# "OpenBioLink": OpenBioLinkDataset(split="train"), | |||
# "OpenKEWiki": OpenKEWikiDataset(split="train"), | |||
} | |||
for name, dataset in datasets.items(): | |||
wl_crec = WLCREC(dataset) | |||
entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = wl_crec.compute(H=20) | |||
print( | |||
f"\nDataset: {name}", | |||
f"\nResults (H={20})", | |||
f"\n • avg WLEC : {entropy:.6f} nats", | |||
f"\n • C_ratio : {c_ratio:.6f}", | |||
f"\n • C_NWLEC : {c_nwlec:.6f}", | |||
f"\n • H_cond(R|S_H) : {h_cond:.6f} nats", | |||
f"\n • D_ratio : {d_ratio:.6f}", | |||
f"\n • D_NWLEC : {d_nwlec:.6f}", | |||
sep="", | |||
) | |||
if __name__ == "__main__": | |||
main() |
@@ -0,0 +1,80 @@ | |||
from copy import deepcopy | |||
import hydra | |||
import lovely_tensors as lt | |||
import torch | |||
from hydra.utils import instantiate | |||
from omegaconf import DictConfig | |||
from metrics.c_swklf import CSWKLF | |||
from tools import get_pretty_logger | |||
# (Import whatever Trainer or runner you have) | |||
from training.trainer import Trainer | |||
logger = get_pretty_logger(__name__) | |||
lt.monkey_patch() | |||
@hydra.main(config_path="configs", config_name="config") | |||
def main(cfg: DictConfig) -> None: | |||
# Detect CUDA | |||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||
common_params = cfg.common | |||
training_params = deepcopy(cfg.training) | |||
    # drop the model_trainer from the copied training config; it is instantiated separately below
del training_params.model_trainer | |||
common_params = instantiate(common_params) | |||
training_params = instantiate(training_params) | |||
# dataset instantiation | |||
train_dataset = instantiate(cfg.data.train) | |||
val_dataset = instantiate(cfg.data.valid) | |||
model = instantiate( | |||
cfg.model, | |||
num_entities=train_dataset.num_entities, | |||
num_relations=train_dataset.num_relations, | |||
triples_factory=train_dataset.triples_factory, | |||
device=device, | |||
) | |||
model = model.to(device) | |||
model_trainer = instantiate( | |||
cfg.training.model_trainer, | |||
model=model, | |||
training_params=training_params, | |||
triples_factory=train_dataset.triples_factory, | |||
mapped_triples=train_dataset.triples, | |||
device=model.device, | |||
) | |||
# Instantiate the Trainer with the model_trainer | |||
trainer = Trainer( | |||
train_dataset=train_dataset, | |||
val_dataset=val_dataset, | |||
model_trainer=model_trainer, | |||
common_params=common_params, | |||
training_params=training_params, | |||
device=model.device, | |||
) | |||
if cfg.common.evaluate_only: | |||
trainer.evaluate() | |||
c_swklf_metric = CSWKLF( | |||
dataset=val_dataset, | |||
model=model, | |||
) | |||
c_swklf_value = c_swklf_metric.compute() | |||
logger.info(f"CSWKLF Metric Value: {c_swklf_value}") | |||
else: | |||
trainer.run() | |||
if __name__ == "__main__": | |||
main() |
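# Launch sketch (illustrative; assumes this entry point is saved as main.py, with the
# Hydra config groups defined by the YAML files earlier in this diff):
#
#   python main.py data=fb15k model=trans_e training=trans_e_trainer common.run_name=TransE_FB15K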
@@ -0,0 +1,55 @@ | |||
from __future__ import annotations | |||
from abc import ABC, abstractmethod | |||
from typing import TYPE_CHECKING, Any, List, Tuple | |||
if TYPE_CHECKING: | |||
from data.kg_dataset import KGDataset | |||
class BaseMetric(ABC): | |||
"""Base class for metrics that compute Weisfeiler-Lehman Entropy | |||
Complexity (WLEC) or similar metrics. | |||
""" | |||
def __init__(self, dataset: KGDataset) -> None: | |||
self.dataset = dataset | |||
self.num_entities = dataset.num_entities | |||
self.out_adj, self.in_adj = self._build_adjacency_lists() | |||
def _build_adjacency_lists( | |||
self, | |||
) -> Tuple[List[List[Tuple[int, int]]], List[List[Tuple[int, int]]]]: | |||
"""Create *in* and *out* adjacency lists from mapped triples. | |||
Parameters | |||
---------- | |||
triples: | |||
Tensor of shape *(m, 3)* containing *(h, r, t)* triples (``dtype=torch.long``). | |||
num_entities: | |||
Total number of entities. Needed to allocate adjacency containers. | |||
Returns | |||
------- | |||
out_adj, in_adj | |||
Each element ``out_adj[v]`` is a list ``[(r, t), ...]``. Each element | |||
``in_adj[v]`` is a list ``[(r, h), ...]``. | |||
""" | |||
triples = self.dataset.triples # (m, 3) | |||
num_entities = self.num_entities | |||
out_adj: List[List[Tuple[int, int]]] = [[] for _ in range(num_entities)] | |||
in_adj: List[List[Tuple[int, int]]] = [[] for _ in range(num_entities)] | |||
for h, r, t in triples.tolist(): | |||
out_adj[h].append((r, t)) | |||
in_adj[t].append((r, h)) | |||
return out_adj, in_adj | |||
@abstractmethod | |||
def compute( | |||
self, | |||
*args, | |||
**kwargs, | |||
) -> Any: ... |
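# Minimal subclass sketch (illustrative only): a toy metric that reports the mean
# out-degree using the adjacency lists built by BaseMetric.
#
#   class MeanOutDegree(BaseMetric):
#       def compute(self) -> float:
#           return sum(len(nbrs) for nbrs in self.out_adj) / max(self.num_entities, 1)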
@@ -0,0 +1,264 @@ | |||
from __future__ import annotations | |||
import math | |||
from collections import defaultdict | |||
from dataclasses import dataclass | |||
from typing import TYPE_CHECKING, Callable, Iterable | |||
import torch | |||
from torch import Tensor | |||
from .base_metric import BaseMetric | |||
if TYPE_CHECKING: | |||
from data.kg_dataset import KGDataset | |||
from models.base_model import ERModel | |||
@torch.no_grad() | |||
def _log_prob_distribution( | |||
fn: Callable[[Tensor], Tensor], | |||
query: Tensor, | |||
batch_size: int, | |||
) -> Iterable[Tensor]: | |||
""" | |||
Yield *log*-probability batches for an arbitrary distribution helper. | |||
Parameters | |||
---------- | |||
fn : | |||
A bound method of *model* returning **log**-probabilities, e.g. | |||
``model.tail_distribution(..., log=True)``. | |||
query : | |||
LongTensor of shape (N, 2) containing the query IDs that `fn` expects. | |||
batch_size : | |||
Mini-batch size. Choose according to available GPU/CPU memory. | |||
""" | |||
for i in range(0, len(query), batch_size): | |||
yield fn(query[i : i + batch_size]) # (B, |Candidates|) | |||
def _kl_uniform(log_p: Tensor, true_index_sets: list[list[int]]) -> Tensor: | |||
""" | |||
KL(q || p) where q is uniform over *true_index_sets*. | |||
Parameters | |||
---------- | |||
log_p : | |||
Tensor of shape (B, N) — log-probabilities from the model. | |||
true_index_sets : | |||
List (length B). Element *b* contains the list of **indices** | |||
that are correct for row *b*. | |||
Returns | |||
------- | |||
kl : Tensor, shape (B,) — KL divergence per row (natural log). | |||
""" | |||
rows: list[Tensor] = [] | |||
for lp_row, idx in zip(log_p, true_index_sets): | |||
k = len(idx) | |||
rows.append(math.log(k) - lp_row[idx].mean()) # log k - E_q[log p] | |||
return torch.tensor(rows, device=log_p.device) | |||
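# Worked example (illustrative): if row b has true indices {2, 5} (so k = 2) and the
# model assigns log-probabilities lp_row[2] = -1.0 and lp_row[5] = -3.0, then
# KL(q || p) = log 2 - (-1.0 - 3.0) / 2 = 0.6931 + 2.0 ≈ 2.6931 nats.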
@dataclass(slots=True) | |||
class _Accum: | |||
"""Simple accumulator for ∑ value and ∑ weight.""" | |||
tot_v: float = 0.0 | |||
tot_w: float = 0.0 | |||
def update(self, value: float, weight: float) -> None: | |||
self.tot_v += value * weight | |||
self.tot_w += weight | |||
@property | |||
def mean(self) -> float: | |||
return self.tot_v / self.tot_w if self.tot_w else float("nan") | |||
class CSWKLF(BaseMetric): | |||
""" | |||
C-SWKLF metric for evaluating knowledge graph embeddings. | |||
This metric is used to evaluate the quality of knowledge graph embeddings | |||
by computing the C-SWKLF score. | |||
""" | |||
def __init__(self, dataset: KGDataset, model: ERModel) -> None: | |||
super().__init__(dataset) | |||
self.model = model | |||
@torch.no_grad() | |||
def compute( | |||
self, | |||
alpha: float = 1 / 3, | |||
beta: float = 1 / 3, | |||
gamma: float = 1 / 3, | |||
batch_size: int = 1024, | |||
slice_size: int | None = 2048, | |||
device: torch.device | str | None = None, | |||
) -> float: | |||
""" | |||
        Compute *Comprehensive Structural-Weighted KL-Fitness* for the model and
        dataset supplied at construction time. The model must be a trained PyKEEN
        ERModel exposing the three distribution helpers used below
        (``tail_distribution``, ``head_distribution``, ``relation_distribution``);
        the dataset provides the mapped triples.
        Parameters
        ----------
alpha, beta, gamma : | |||
Non-negative weights for the three query types, default 1/3 each. | |||
Must sum to *1*. | |||
batch_size : | |||
Number of queries scored per forward pass. Tune wrt. GPU/CPU RAM. | |||
slice_size : | |||
Forwarded to `tail_distribution` / `head_distribution` / | |||
`relation_distribution`. `None` disables slicing. | |||
device : | |||
Where to do the computation. `None` ⇒ first param's device. | |||
Returns | |||
------- | |||
score : float in (0, 1] | |||
Higher = model closer to empirical distributions across *all* directions. | |||
""" | |||
# --------------------------------------------------------------------- # | |||
# Preparation # | |||
# --------------------------------------------------------------------- # | |||
assert abs(alpha + beta + gamma - 1.0) < 1e-9, "α+β+γ must be 1." | |||
model_device = next(self.model.parameters()).device | |||
if device is None: | |||
device = model_device | |||
triples = self.dataset.triples.to(device) # (|T|, 3) | |||
heads, rels, tails = triples.t() # (|T|,) | |||
# Build index structures -------------------------------------------------- | |||
tails_by_hr: dict[tuple[int, int], list[int]] = defaultdict(list) | |||
heads_by_rt: dict[tuple[int, int], list[int]] = defaultdict(list) | |||
rels_by_ht: dict[tuple[int, int], list[int]] = defaultdict(list) | |||
for h, r, t in zip(heads.tolist(), rels.tolist(), tails.tolist()): | |||
tails_by_hr[(h, r)].append(t) | |||
heads_by_rt[(r, t)].append(h) | |||
rels_by_ht[(h, t)].append(r) | |||
# Structural entropies & query lists --------------------------------------- | |||
H_tail: dict[int, _Accum] = defaultdict(_Accum) # per relation r | |||
H_head: dict[int, _Accum] = defaultdict(_Accum) | |||
tail_queries: list[tuple[int, int]] = [] # (h,r) | |||
head_queries: list[tuple[int, int]] = [] # (r,t) | |||
rel_queries: list[tuple[int, int]] = [] # (h,t) | |||
# --- tails -------------------------------------------------------------- | |||
for (h, r), ts in tails_by_hr.items(): | |||
k = len(ts) | |||
H_tail[r].update(math.log(k), 1.0) | |||
tail_queries.append((h, r)) | |||
# --- heads -------------------------------------------------------------- | |||
for (r, t), hs in heads_by_rt.items(): | |||
k = len(hs) | |||
H_head[r].update(math.log(k), 1.0) | |||
head_queries.append((r, t)) | |||
# --- relations ---------------------------------------------------------- | |||
H_rel_accum = _Accum() | |||
for (h, t), rs in rels_by_ht.items(): | |||
k = len(rs) | |||
H_rel_accum.update(math.log(k), 1.0) | |||
rel_queries.append((h, t)) | |||
H_struct_tail = {r: acc.mean for r, acc in H_tail.items()} | |||
H_struct_head = {r: acc.mean for r, acc in H_head.items()} | |||
# H_struct_rel = H_rel_accum.mean # scalar | |||
# --------------------------------------------------------------------- # | |||
# KL divergences # | |||
# --------------------------------------------------------------------- # | |||
def _avg_kl_per_relation( | |||
queries: list[tuple[int, int]], | |||
true_idx_map: dict[tuple[int, int], list[int]], | |||
distribution_fn: Callable[[Tensor], Tensor], | |||
            num_relations: bool = False,  # True when the first element of each query key is the relation ID (head queries)
) -> dict[int, float]: | |||
""" | |||
Compute *average* KL(q || p) **per relation**. | |||
Works for the tail & head conditionals. | |||
""" | |||
kl_sum: dict[int, float] = defaultdict(float) | |||
count: dict[int, int] = defaultdict(int) | |||
query_tensor = torch.tensor(queries, dtype=torch.long, device=device) | |||
for log_p in _log_prob_distribution(distribution_fn, query_tensor, batch_size): | |||
# slice_size handled inside distribution_fn | |||
# log_p : (B, |Candidates|) | |||
b = log_p.shape[0] | |||
# Map back to current slice of queries | |||
slice_queries = queries[:b] | |||
queries[:] = queries[b:] # consume | |||
true_sets = [true_idx_map[q] for q in slice_queries] | |||
kl_row = _kl_uniform(log_p, true_sets) # (B,) | |||
for (x, r), kl_val in zip(slice_queries, kl_row.tolist()): | |||
rel_id = r if not num_relations else x | |||
kl_sum[rel_id] += kl_val | |||
count[rel_id] += 1 | |||
return {r: kl_sum[r] / count[r] for r in kl_sum} | |||
# --- tails -------------------------------------------------------------- | |||
tail_queries_copy = tail_queries.copy() | |||
D_tail = _avg_kl_per_relation( | |||
tail_queries_copy, | |||
tails_by_hr, | |||
lambda q, s=slice_size: self.model.tail_distribution(q, log=True, slice_size=s), | |||
) | |||
S_tail = {r: math.exp(-d) for r, d in D_tail.items()} | |||
# --- heads -------------------------------------------------------------- | |||
head_queries_copy = head_queries.copy() | |||
D_head = _avg_kl_per_relation( | |||
head_queries_copy, | |||
heads_by_rt, | |||
lambda q, s=slice_size: self.model.head_distribution(q, log=True, slice_size=s), | |||
num_relations=True, | |||
) | |||
S_head = {r: math.exp(-d) for r, d in D_head.items()} | |||
# --- relations (single global number) ---------------------------------- | |||
D_rel_sum = 0.0 | |||
for log_p in _log_prob_distribution( | |||
lambda q, s=slice_size: self.model.relation_distribution(q, log=True, slice_size=s), | |||
torch.tensor(rel_queries, dtype=torch.long, device=device), | |||
batch_size, | |||
): | |||
slice_size_batch = log_p.shape[0] | |||
slice_queries = rel_queries[:slice_size_batch] | |||
rel_queries[:] = rel_queries[slice_size_batch:] | |||
true_sets = [rels_by_ht[q] for q in slice_queries] | |||
D_rel_sum += _kl_uniform(log_p, true_sets).sum().item() | |||
D_rel = D_rel_sum / H_rel_accum.tot_w | |||
S_rel = math.exp(-D_rel) | |||
# --------------------------------------------------------------------- # | |||
# Weighted aggregation → C-SWKLF # | |||
# --------------------------------------------------------------------- # | |||
def _weighted_mean(score: dict[int, float], weight: dict[int, float]) -> float: | |||
num = sum(weight[r] * score[r] for r in score) | |||
den = sum(weight[r] for r in score) | |||
return num / den if den else float("nan") | |||
sw_tail = _weighted_mean(S_tail, H_struct_tail) | |||
sw_head = _weighted_mean(S_head, H_struct_head) | |||
# Relations: numerator & denominator cancel | |||
sw_rel = S_rel | |||
return alpha * sw_tail + beta * sw_head + gamma * sw_rel |
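# Usage sketch (mirrors the evaluate-only branch of the Hydra entry point earlier in
# this diff; `val_dataset` and `model` are assumed to already exist):
#
#   metric = CSWKLF(dataset=val_dataset, model=model)
#   score = metric.compute(alpha=1 / 3, beta=1 / 3, gamma=1 / 3, batch_size=1024)
#   print(f"C-SWKLF = {score:.4f}")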
@@ -0,0 +1,732 @@ | |||
# from __future__ import annotations | |||
# import math | |||
# import os | |||
# import random | |||
# import threading | |||
# from collections import defaultdict | |||
# from concurrent.futures import ThreadPoolExecutor, as_completed | |||
# from typing import Dict, List, Tuple | |||
# import torch | |||
# from data.kg_dataset import KGDataset | |||
# from metrics.wlcrec import WLCREC # Assuming WLCREC is defined in | |||
# from tools import get_pretty_logger | |||
# logger = get_pretty_logger(__name__) | |||
# # -------------------------------------------------------------------- | |||
# # Edge-editing primitives | |||
# # -------------------------------------------------------------------- | |||
# # -- additions -------------------- | |||
# def add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng): | |||
# for _ in range(1000): | |||
# h = rng.randrange(n_ent) | |||
# t = rng.randrange(n_ent) | |||
# if colours[h] != colours[t]: | |||
# r = rng.randrange(n_rel) | |||
# triples.append((h, r, t)) | |||
# out_adj[h].append((r, t)) | |||
# in_adj[t].append((r, h)) | |||
# return | |||
# def add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng): | |||
# σ = rng.choice(colours) | |||
# τ = rng.choice(colours) | |||
# h = colours.index(σ) | |||
# t = colours.index(τ) | |||
# triples.append((h, 0, t)) | |||
# out_adj[h].append((0, t)) | |||
# in_adj[t].append((0, h)) | |||
# # -- removals -------------------- | |||
# def rem_div_edge(triples, out_adj, in_adj, colours, depth, rng): | |||
# cand = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]] | |||
# if cand: | |||
# idx = rng.choice(cand) | |||
# h, r, t = triples.pop(idx) | |||
# out_adj[h].remove((r, t)) | |||
# in_adj[t].remove((r, h)) | |||
# def rem_det_edge(triples, out_adj, in_adj, colours, depth, rng): | |||
# sig_rel = defaultdict(set) | |||
# for h, r, t in triples: | |||
# σ, τ = colours[h], colours[t] | |||
# sig_rel[(σ, τ)].add(r) | |||
# cand = [] | |||
# for i, (h, r, t) in enumerate(triples): | |||
# if len(sig_rel[(colours[h], colours[t])]) == 1: | |||
# cand.append(i) | |||
# if cand: | |||
# idx = rng.choice(cand) | |||
# h, r, t = triples.pop(idx) | |||
# out_adj[h].remove((r, t)) | |||
# in_adj[t].remove((r, h)) | |||
# def _search_worker( | |||
# seed: int, | |||
# triples_init: List[Tuple[int, int, int]], | |||
# triples_factory, | |||
# n_ent: int, | |||
# n_rel: int, | |||
# depth: int, | |||
# lo: float, | |||
# hi: float, | |||
# max_iters: int, | |||
# ) -> List[Tuple[int, int, int]] | None: | |||
# """ | |||
# Run the exact same hill‑climb that `tune_crec()` did, | |||
# but entirely in this process. Return the edited triples | |||
# once c falls in [lo, hi]; return None if we used up all iterations. | |||
# """ | |||
# rng = random.Random(seed) | |||
# triples = triples_init.copy().tolist() | |||
# for it in range(max_iters): | |||
# # WL‑CREC is *only* recomputed every 1000 edits, exactly like before | |||
# if it % 1000 == 0: | |||
# dataset = KGDataset( | |||
# triples_factory, | |||
# triples=torch.tensor(triples, dtype=torch.long), | |||
# num_entities=n_ent, | |||
# num_relations=n_rel, | |||
# ) | |||
# wl = WLCREC(dataset) | |||
# colours = wl.wl_colours(depth) | |||
# *_, c, _ = wl.compute(H=5) # unchanged API | |||
# if lo <= c <= hi: # success | |||
# logger.info("[seed %d] hit %.4f in %d edits", seed, c, it) | |||
# return triples | |||
# # ---------- identical edit logic ---------- | |||
# if c < lo: | |||
# if rng.random() < 0.5: | |||
# rem_det_edge(triples, colours, depth, rng) | |||
# else: | |||
# add_div_edge(triples, colours, depth, n_ent, n_rel, rng) | |||
# else: # c > hi | |||
# if rng.random() < 0.5: | |||
# rem_div_edge(triples, colours, depth, rng) | |||
# else: | |||
# add_det_edge(triples, colours, depth, n_rel, rng) | |||
# # ------------------------------------------ | |||
# return None # used up our budget | |||
# # -------------------------------------------------------------------- | |||
# # Unified tuner | |||
# # -------------------------------------------------------------------- | |||
# def tune_crec( | |||
# wl_crec: WLCREC, | |||
# n_ent: int, | |||
# n_rel: int, | |||
# depth: int, | |||
# lo: float, | |||
# hi: float, | |||
# max_iters: int = 80_000, | |||
# seed: int = 42, | |||
# ): | |||
# triples = wl_crec.dataset.triples.tolist() | |||
# rng = random.Random(seed) | |||
# for it in range(max_iters): | |||
# print(f"\r[iter {it + 1:5d}] ", end="") | |||
# if it % 1000 == 0: | |||
# dataset = KGDataset( | |||
# wl_crec.dataset.triples_factory, | |||
# triples=torch.tensor(triples, dtype=torch.long), | |||
# num_entities=n_ent, | |||
# num_relations=n_rel, | |||
# ) | |||
# tmp_wl_crec = WLCREC(dataset) | |||
# # colours = wl_colours(triples, n_ent, depth) | |||
# colours = tmp_wl_crec.wl_colours(depth) | |||
# _, _, _, _, c, _ = tmp_wl_crec.compute(H=5) | |||
# if lo <= c <= hi: | |||
# logging.info("WL-CREC %.4f reached after %d edits (|T|=%d)", c, it, len(triples)) | |||
# return triples | |||
# if c < lo: | |||
# # need ↑ WL-CREC → prefer deletion of deterministic, else add diversifying | |||
# if rng.random() < 0.5: | |||
# rem_det_edge(triples, colours, depth, rng) | |||
# else: | |||
# add_div_edge(triples, colours, depth, n_ent, n_rel, rng) | |||
# # need ↓ WL-CREC → prefer deletion of diversifying, else add deterministic | |||
# elif rng.random() < 0.5: | |||
# rem_div_edge(triples, colours, depth, rng) | |||
# else: | |||
# add_det_edge(triples, colours, depth, n_rel, rng) | |||
# if (it + 1) % 10000 == 1: | |||
# logging.info("[iter %d] WL-CREC %.4f |T|=%d", it + 1, c, len(triples)) | |||
# raise RuntimeError("Exceeded max iterations without hitting target band.") | |||
# def _edit_batch( | |||
# worker_id: int, | |||
# n_edits: int, | |||
# colours: List[int], | |||
# crec: float, | |||
# target_lo: float, | |||
# target_hi: float, | |||
# depth: int, | |||
# n_ent: int, | |||
# n_rel: int, | |||
# seed_base: int, | |||
# ): | |||
# """ | |||
# Perform `n_edits` topology modifications *locally* and return | |||
# (added_triples, removed_triples) lists. | |||
# """ | |||
# rng = random.Random(seed_base + worker_id) | |||
# local_added, local_removed = [], [] | |||
# # local graph views: only sizes matter for edit selection | |||
# for _ in range(n_edits): | |||
# if crec < target_lo: # need ↑ WL-CREC | |||
# if rng.random() < 0.5: # • remove deterministic | |||
# # We cannot remove by index safely without the global list; | |||
# # choose a *signature* and remember intention to delete: | |||
# local_removed.append(("det", rng.random())) | |||
# else: # • add diversifying | |||
# h = rng.randrange(n_ent) | |||
# t = rng.randrange(n_ent) | |||
# while colours[h] == colours[t]: | |||
# h = rng.randrange(n_ent) | |||
# t = rng.randrange(n_ent) | |||
# r = rng.randrange(n_rel) | |||
# local_added.append((h, r, t)) | |||
# else: # need ↓ WL-CREC | |||
# if rng.random() < 0.5: # • remove diversifying | |||
# local_removed.append(("div", rng.random())) | |||
# else: # • add deterministic | |||
# σ = rng.choice(colours) | |||
# τ = rng.choice(colours) | |||
# h = colours.index(σ) | |||
# t = colours.index(τ) | |||
# local_added.append((h, 0, t)) | |||
# return local_added, local_removed | |||
# def wl_colours(triples, n_ent, depth): | |||
# # build adjacency once | |||
# out_adj, in_adj = defaultdict(list), defaultdict(list) | |||
# for h, r, t in triples: | |||
# out_adj[h].append((r, t)) | |||
# in_adj[t].append((r, h)) | |||
# colours_rounds = [[0] * n_ent] # round-0 colours | |||
# for h in range(1, depth + 1): | |||
# prev = colours_rounds[-1] | |||
# # 1) build textual signatures in parallel | |||
# def sig(v): | |||
# neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [ | |||
# ("↑", r, prev[u]) for r, u in in_adj.get(v, []) | |||
# ] | |||
# neigh.sort() | |||
# return (prev[v], tuple(neigh)) | |||
# with ThreadPoolExecutor() as tpe: # cheap threads inside worker | |||
# sigs = list(tpe.map(sig, range(n_ent))) | |||
# # 2) assign deterministic colour IDs | |||
# sig2id: Dict[Tuple, int] = {} | |||
# next_round = [0] * n_ent | |||
# fresh = 0 | |||
# for v, sg in enumerate(sigs): | |||
# cid = sig2id.setdefault(sg, fresh) | |||
# if cid == fresh: | |||
# fresh += 1 | |||
# next_round[v] = cid | |||
# colours_rounds.append(next_round) | |||
# depth_colours = colours_rounds[-1] | |||
# return depth_colours | |||
# def _metric_worker(args): | |||
# triples, n_ent, n_rel, depth = args | |||
# dataset = KGDataset( | |||
# triples_factory=None, # Not used in this context | |||
# triples=torch.tensor(triples, dtype=torch.long), | |||
# num_entities=n_ent, | |||
# num_relations=n_rel, | |||
# ) | |||
# wl_crec = WLCREC(dataset) | |||
# _, _, _, _, c, _ = wl_crec.compute(H=5, return_full=False) | |||
# colours = wl_colours(triples, n_ent, depth) | |||
# return c, colours | |||
# def tune_crec_parallel_edits( | |||
# triples_init, | |||
# n_ent: int, | |||
# n_rel: int, | |||
# depth: int, | |||
# target_lo: float, | |||
# target_hi: float, | |||
# max_iters: int = 80_000, | |||
# metric_every: int = 100, | |||
# n_workers: int = max(20, math.ceil(os.cpu_count() / 2)), | |||
# seed: int = 42, | |||
# ): | |||
# # -------- shared mutable state (main thread owns it) -------- | |||
# triples = triples_init.tolist() | |||
# out_adj, in_adj = defaultdict(list), defaultdict(list) | |||
# for h, r, t in triples: | |||
# out_adj[h].append((r, t)) | |||
# in_adj[t].append((r, h)) | |||
# pool = ThreadPoolExecutor(max_workers=n_workers) | |||
# metric_lock = threading.Lock() # exactly one metric at a time | |||
# rng_global = random.Random(seed) | |||
# # ----- first metric checkpoint ----- | |||
# crec, colours = _metric_worker((triples, n_ent, n_rel, depth)) | |||
# edit_budget_total = 0 | |||
# for it in range(0, max_iters, metric_every): | |||
# # ========================================================= | |||
# # 1. PARALLEL EDIT STAGE (metric_every edits in total) | |||
# # ========================================================= | |||
# futures = [] | |||
# edits_per_worker = metric_every // n_workers | |||
# extra = metric_every % n_workers | |||
# for wid in range(n_workers): | |||
# n_edits = edits_per_worker + (1 if wid < extra else 0) | |||
# futures.append( | |||
# pool.submit( | |||
# _edit_batch, | |||
# wid, | |||
# n_edits, | |||
# colours, | |||
# crec, | |||
# target_lo, | |||
# target_hi, | |||
# depth, | |||
# n_ent, | |||
# n_rel, | |||
# seed, | |||
# ) | |||
# ) | |||
# # merge when workers finish | |||
# for fut in as_completed(futures): | |||
# added, removed_specs = fut.result() | |||
# # --- apply additions immediately (cheap, conflict-free) | |||
# for h, r, t in added: | |||
# triples.append((h, r, t)) | |||
# out_adj[h].append((r, t)) | |||
# in_adj[t].append((r, h)) | |||
# # --- apply removals: interpret the spec on *current* graph | |||
# for typ, randv in removed_specs: | |||
# if typ == "div": | |||
# idxs = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]] | |||
# else: # 'det' | |||
# sig_rel = defaultdict(set) | |||
# for h, r, t in triples: | |||
# sig_rel[(colours[h], colours[t])].add(r) | |||
# idxs = [ | |||
# i | |||
# for i, (h, r, t) in enumerate(triples) | |||
# if len(sig_rel[(colours[h], colours[t])]) == 1 | |||
# ] | |||
# if idxs: | |||
# victim = idxs[int(randv * len(idxs))] | |||
# h, r, t = triples.pop(victim) | |||
# out_adj[h].remove((r, t)) | |||
# in_adj[t].remove((r, h)) | |||
# edit_budget_total += metric_every | |||
# # ========================================================= | |||
# # 2. SINGLE-THREADED METRIC CHECKPOINT (synchronised) | |||
# # ========================================================= | |||
# # if edit_budget_total % (10 * metric_every) == 0: | |||
# if True: | |||
# with metric_lock: | |||
# crec, colours = _metric_worker((triples, n_ent, n_rel, depth)) | |||
# logging.info( | |||
# "After %d edits WL-CREC = %.4f |T|=%d", edit_budget_total, crec, len(triples) | |||
# ) | |||
# if target_lo <= crec <= target_hi: | |||
# logging.info("Target band reached.") | |||
# pool.shutdown(wait=True) | |||
# return triples | |||
# pool.shutdown(wait=True) | |||
# raise RuntimeError("Exceeded max iteration budget without success.") | |||
# from __future__ import annotations | |||
# import logging | |||
# import random | |||
# from collections import defaultdict | |||
# from concurrent.futures import ThreadPoolExecutor, as_completed | |||
# from typing import List, Tuple | |||
# import torch | |||
# from data.kg_dataset import KGDataset | |||
# from metrics.wlcrec import WLCREC | |||
# from tools import get_pretty_logger | |||
# logger = get_pretty_logger(__name__) | |||
# # ------------ additions ------------ | |||
# def propose_add_div(triples: List, colours, n_ent: int, n_rel: int, rng: random.Random): | |||
# for _ in range(1000): | |||
# h, t = rng.randrange(n_ent), rng.randrange(n_ent) | |||
# if colours[h] != colours[t]: | |||
# r = rng.randrange(n_rel) | |||
# return ("add", (h, r, t)) | |||
# return None # fell through – extremely rare | |||
# def propose_add_det(colours, n_rel: int, rng: random.Random): | |||
# σ, τ = rng.choice(colours), rng.choice(colours) | |||
# h, t = colours.index(σ), colours.index(τ) | |||
# return ("add", (h, 0, t)) # rel 0 = deterministic | |||
# # ------------ removals ------------ | |||
# def propose_rem_div(triples: List, colours, rng: random.Random): | |||
# cand = [trp for trp in triples if colours[trp[0]] != colours[trp[2]]] | |||
# return ("rem", rng.choice(cand)) if cand else None | |||
# def propose_rem_det(triples: List, colours, rng: random.Random): | |||
# sig_rel = defaultdict(set) | |||
# for h, r, t in triples: | |||
# sig_rel[(colours[h], colours[t])].add(r) | |||
# cand = [trp for trp in triples if len(sig_rel[(colours[trp[0]], colours[trp[2]])]) == 1] | |||
# return ("rem", rng.choice(cand)) if cand else None | |||
# def make_edit_proposal( | |||
# triples_snapshot: List, | |||
# colours_snapshot, | |||
# c: float, | |||
# lo: float, | |||
# hi: float, | |||
# n_ent: int, | |||
# n_rel: int, | |||
# depth: int, # still here for future use / signatures | |||
# seed: int, | |||
# ): | |||
# """Return exactly one ('add' | 'rem', triple) proposal or None.""" | |||
# rng = random.Random(seed) | |||
# # -- decide which kind of edit we want, *given* the current c ------------ | |||
# if c < lo: # ↑ WL‑CREC (delete deterministic ∨ add diversifying) | |||
# chooser = (propose_rem_det, propose_add_div) | |||
# else: # ↓ WL‑CREC (delete diversifying ∨ add deterministic) | |||
# chooser = (propose_rem_div, propose_add_det) | |||
# op = rng.choice(chooser) | |||
# return ( | |||
# op(triples_snapshot, colours_snapshot, n_ent, n_rel, rng) | |||
# if op.__name__.startswith("propose_add") | |||
# else op(triples_snapshot, colours_snapshot, rng) | |||
# ) | |||
# def tune_crec_parallel_edits( | |||
# triples: List, | |||
# n_ent: int, | |||
# n_rel: int, | |||
# depth: int, | |||
# lo: float, | |||
# hi: float, | |||
# *, | |||
# max_iters: int = 80_000, | |||
# edits_per_eval: int = 1000, # == old “if it % 1000 == 0” | |||
# batch_size: int = 256, # how many proposals we farm out at once | |||
# max_workers: int = 4, | |||
# seed: int = 42, | |||
# ) -> List: | |||
# rng_global = random.Random(seed) | |||
# triples = set(triples) # deduplicate if needed | |||
# with ThreadPoolExecutor(max_workers=max_workers) as pool: | |||
# proposal_seed = seed * 997 # deterministic but different stream | |||
# # edit_counter = 0 | |||
# # while edit_counter < max_iters: | |||
# # # ----------------- expensive part (single‑thread) ---------------- | |||
# # dataset = KGDataset( | |||
# # triples_factory=None, # Not used in this context | |||
# # triples=torch.tensor(triples, dtype=torch.long), | |||
# # num_entities=n_ent, | |||
# # num_relations=n_rel, | |||
# # ) | |||
# # tmp = WLCREC(dataset) | |||
# # colours = tmp.wl_colours(depth) | |||
# # *_, c, _ = tmp.compute(H=5) | |||
# # # ----------------------------------------------------------------- | |||
# # if lo <= c <= hi: | |||
# # logger.info( | |||
# # "WL‑CREC %.4f reached after %d edits |T|=%d", c, edit_counter, len(triples) | |||
# # ) | |||
# # return triples | |||
# # ============ parallel block: just make `edits_per_eval` proposals | |||
# needed = min(edits_per_eval, max_iters - edit_counter) | |||
# proposals = [] | |||
# while len(proposals) < needed: | |||
# # launch a batch of workers | |||
# futs = [ | |||
# pool.submit( | |||
# make_edit_proposal, | |||
# triples, | |||
# colours, | |||
# c, | |||
# lo, | |||
# hi, | |||
# n_ent, | |||
# n_rel, | |||
# depth, | |||
# proposal_seed + i, | |||
# ) | |||
# for i in range(batch_size) | |||
# ] | |||
# for f in as_completed(futs): | |||
# prop = f.result() | |||
# if prop is not None: | |||
# proposals.append(prop) | |||
# if len(proposals) == needed: | |||
# break | |||
# proposal_seed += batch_size # move RNG window forward | |||
# # -------------- apply the gathered proposals *sequentially* ------- | |||
# for kind, trp in proposals: | |||
# if kind == "add": | |||
# triples.append(trp) | |||
# else: # "rem" | |||
# try: | |||
# triples.remove(trp) | |||
# except ValueError: | |||
# pass # already gone – benign collision | |||
# # ----------------------------------------------------------------- | |||
# edit_counter += needed | |||
# if edit_counter % 1_000 == 0: | |||
# logger.info("[iter %d] c=%.4f |T|=%d", edit_counter, c, len(triples)) | |||
# raise RuntimeError("Exceeded max_iters without hitting target band.") | |||
import os | |||
import random | |||
import threading | |||
from collections import defaultdict | |||
from typing import Dict, List, Tuple | |||
# -------------------------------------------------------------------- | |||
# Logging helper (replace with your own if you prefer) -------------- | |||
# -------------------------------------------------------------------- | |||
from tools import get_pretty_logger | |||
logger = get_pretty_logger(__name__)
# -------------------------------------------------------------------- | |||
# Original edge‑editing primitives (unchanged) ---------------------- | |||
# -------------------------------------------------------------------- | |||
def add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng): | |||
for _ in range(1000): | |||
h = rng.randrange(n_ent) | |||
t = rng.randrange(n_ent) | |||
if colours[h] != colours[t]: | |||
r = rng.randrange(n_rel) | |||
triples.append((h, r, t)) | |||
out_adj[h].append((r, t)) | |||
in_adj[t].append((r, h)) | |||
return | |||
def add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng): | |||
σ = rng.choice(colours) | |||
τ = rng.choice(colours) | |||
h = colours.index(σ) | |||
t = colours.index(τ) | |||
triples.append((h, 0, t)) | |||
out_adj[h].append((0, t)) | |||
in_adj[t].append((0, h)) | |||
def rem_div_edge(triples, out_adj, in_adj, colours, depth, rng): | |||
cand = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]] | |||
if cand: | |||
idx = rng.choice(cand) | |||
h, r, t = triples.pop(idx) | |||
out_adj[h].remove((r, t)) | |||
in_adj[t].remove((r, h)) | |||
def rem_det_edge(triples, out_adj, in_adj, colours, depth, rng): | |||
sig_rel = defaultdict(set) | |||
for h, r, t in triples: | |||
σ, τ = colours[h], colours[t] | |||
sig_rel[(σ, τ)].add(r) | |||
cand = [] | |||
for i, (h, r, t) in enumerate(triples): | |||
if len(sig_rel[(colours[h], colours[t])]) == 1: | |||
cand.append(i) | |||
if cand: | |||
idx = rng.choice(cand) | |||
h, r, t = triples.pop(idx) | |||
out_adj[h].remove((r, t)) | |||
in_adj[t].remove((r, h)) | |||
def _worker( | |||
*, | |||
worker_id: int, | |||
rng: random.Random, | |||
triples: List[Tuple[int, int, int]], | |||
out_adj: Dict[int, List[Tuple[int, int]]], | |||
in_adj: Dict[int, List[Tuple[int, int]]], | |||
colours: List[int], | |||
n_ent: int, | |||
n_rel: int, | |||
depth: int, | |||
c: float, | |||
lo: float, | |||
hi: float, | |||
max_iters: int, | |||
state_lock: threading.Lock, | |||
stop_event: threading.Event, | |||
): | |||
"""One thread: mutate the *shared* structures until success/stop.""" | |||
for it in range(max_iters): | |||
if stop_event.is_set(): | |||
return # someone else finished ─ exit early | |||
with state_lock: # protect the shared graph | |||
# if lo <= c <= hi: | |||
# logging.info( | |||
# "[worker %d] converged after %d steps (CREC %.4f, |T|=%d)", | |||
# worker_id, | |||
# it, | |||
# c, | |||
# len(triples), | |||
# ) | |||
# stop_event.set() | |||
# return | |||
            # Choose and apply one edit ------------------------------------
            # NOTE: the convergence check above is commented out, so `c` stays
            # fixed and each worker simply applies up to `max_iters` edits.
if c < lo: # need ↑ CREC | |||
if rng.random() < 0.5: | |||
rem_det_edge(triples, out_adj, in_adj, colours, depth, rng) | |||
else: | |||
add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng) | |||
elif rng.random() < 0.5: | |||
rem_div_edge(triples, out_adj, in_adj, colours, depth, rng) | |||
else: | |||
add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng) | |||
logging.warning("[worker %d] reached max_iters", worker_id) | |||
stop_event.set() | |||
return | |||
# -------------------------------------------------------------------- | |||
# Public API -------------------------------------------------------- | |||
# -------------------------------------------------------------------- | |||
def fast_tune_crec( | |||
triples: List, | |||
colours: List, | |||
n_ent: int, | |||
n_rel: int, | |||
depth: int, | |||
c: float, | |||
lo: float, | |||
hi: float, | |||
max_iters: int = 1000, | |||
max_workers: int | None = None, | |||
seeds: List[int] | None = None, | |||
) -> List[Tuple[int, int, int]]: | |||
"""Tune WL‑CREC with *shared* triples using multiple threads. | |||
Returns the **same list instance** that was passed in – already | |||
modified in place by the winning thread. | |||
""" | |||
if max_workers is None: | |||
# max_workers = os.cpu_count() or 4 | |||
max_workers = 4 | |||
if seeds is None: | |||
seeds = [42 + i for i in range(max_workers)] | |||
assert len(seeds) >= max_workers, "Need at least one seed per worker" | |||
# Prepare adjacency once (shared) -------------------------------- | |||
out_adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list) | |||
in_adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list) | |||
for h, r, t in triples: | |||
out_adj[h].append((r, t)) | |||
in_adj[t].append((r, h)) | |||
state_lock = threading.Lock() | |||
stop_event = threading.Event() | |||
logging.info( | |||
"Launching %d threads on shared triples (target %.3f–%.3f)", | |||
max_workers, | |||
lo, | |||
hi, | |||
) | |||
threads: List[threading.Thread] = [] | |||
for wid in range(max_workers): | |||
t = threading.Thread( | |||
name=f"tune‑crec‑worker‑{wid}", | |||
target=_worker, | |||
kwargs=dict( | |||
worker_id=wid, | |||
rng=random.Random(seeds[wid]), | |||
triples=triples, | |||
out_adj=out_adj, | |||
in_adj=in_adj, | |||
colours=colours, | |||
n_ent=n_ent, | |||
n_rel=n_rel, | |||
depth=depth, | |||
c=c, | |||
lo=lo, | |||
hi=hi, | |||
max_iters=max_iters, | |||
state_lock=state_lock, | |||
stop_event=stop_event, | |||
), | |||
daemon=False, | |||
) | |||
threads.append(t) | |||
t.start() | |||
for t in threads: | |||
t.join() | |||
if not stop_event.is_set(): | |||
raise RuntimeError("No thread converged – try increasing max_iters or widen band") | |||
return triples |
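# --------------------------------------------------------------------
# Example usage (illustrative sketch, not part of the public API).
# The toy triples, colours and the CREC estimate `c` below are made up;
# in real use they come from WLCREC.wl_colours / WLCREC.compute, as in
# the commented-out driver above.
# --------------------------------------------------------------------
if __name__ == "__main__":
    toy_triples = [(0, 0, 1), (1, 0, 2), (2, 1, 3), (3, 0, 0)]
    toy_colours = [0, 0, 1, 1]  # one WL colour id per entity (assumed)
    edited = fast_tune_crec(
        toy_triples,
        toy_colours,
        n_ent=4,
        n_rel=2,
        depth=2,
        c=0.10,        # pretend the current WL-CREC lies below the band
        lo=0.25,
        hi=0.30,
        max_iters=50,  # keep the demo fast
        max_workers=2,
    )
    logging.info("Demo finished: |T|=%d after editing", len(edited))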
@@ -0,0 +1,106 @@ | |||
from __future__ import annotations | |||
import logging | |||
import random | |||
from collections import defaultdict, deque | |||
from math import isclose | |||
from typing import List | |||
import torch | |||
from data.kg_dataset import KGDataset | |||
from metrics.wlcrec import WLCREC | |||
from tools import get_pretty_logger | |||
logging = get_pretty_logger(__name__) | |||
# -------------------------------------------------------------------- | |||
# Radius-k patch sampler | |||
# -------------------------------------------------------------------- | |||
def radius_patch_sample( | |||
all_triples: List, | |||
ratio: float, | |||
radius: int, | |||
rng: random.Random, | |||
) -> List: | |||
""" | |||
Take ~ratio fraction of triples by picking random seeds and | |||
keeping *all* triples within ≤ radius hops (undirected) of them. | |||
""" | |||
n_total = len(all_triples) | |||
target = int(ratio * n_total) | |||
# Build entity → triple-indices adjacency for BFS | |||
ent2trip = defaultdict(list) | |||
for idx, (h, _, t) in enumerate(all_triples): | |||
ent2trip[h].append(idx) | |||
ent2trip[t].append(idx) | |||
chosen_idx: set[int] = set() | |||
visited_ent = set() | |||
while len(chosen_idx) < target: | |||
seed_idx = rng.randrange(n_total) | |||
if seed_idx in chosen_idx: | |||
continue | |||
# BFS over entities | |||
queue = deque() | |||
for e in all_triples[seed_idx][::2]: # subject & object | |||
if e not in visited_ent: | |||
queue.append((e, 0)) | |||
visited_ent.add(e) | |||
        while queue:
            ent, dist = queue.popleft()
            for tidx in ent2trip[ent]:
                chosen_idx.add(tidx)  # set semantics make repeated adds harmless
                if dist < radius:
                    h, _, t = all_triples[tidx]
                    for nb in (h, t):
                        if nb not in visited_ent:
                            visited_ent.add(nb)
                            queue.append((nb, dist + 1))
# guard against infinite loop on very small ratio | |||
if isclose(ratio, 0.0) and chosen_idx: | |||
break | |||
return [all_triples[i] for i in chosen_idx] | |||
# -------------------------------------------------------------------- | |||
# Main driver: repeated sampling until WL-CREC band hit | |||
# -------------------------------------------------------------------- | |||
def find_sample( | |||
all_triples: List, | |||
n_ent: int, | |||
n_rel: int, | |||
depth: int, | |||
ratio: float, | |||
radius: int, | |||
lower: float, | |||
upper: float, | |||
max_tries: int, | |||
seed: int, | |||
) -> List: | |||
for trial in range(1, max_tries + 1): | |||
rng = random.Random(seed + trial) # fresh randomness each try | |||
sample = radius_patch_sample(all_triples, ratio, radius, rng) | |||
sample_ds = KGDataset( | |||
triples_factory=None, | |||
triples=torch.tensor(sample, dtype=torch.long), | |||
num_entities=n_ent, | |||
num_relations=n_rel, | |||
split="train", | |||
) | |||
wl_crec = WLCREC(sample_ds) | |||
crec_val = wl_crec.compute(depth)[-2] | |||
logging.info("Try %d |T'|=%d WL-CREC=%.4f", trial, len(sample), crec_val) | |||
if lower <= crec_val <= upper: | |||
logging.info("Success after %d tries.", trial) | |||
return sample | |||
raise RuntimeError(f"No sample reached WL-CREC band in {max_tries} tries.") |
@@ -0,0 +1,226 @@ | |||
from __future__ import annotations | |||
import logging | |||
import random | |||
from math import log | |||
from typing import Dict, List, Tuple | |||
import torch | |||
from torch import Tensor | |||
from tools import get_pretty_logger | |||
logging = get_pretty_logger(__name__) | |||
Entity = int | |||
Relation = int | |||
Triple = Tuple[Entity, Relation, Entity] | |||
Color = int | |||
# --------------------------------------------------------------------- | |||
# Weisfeiler–Lehman colouring (GPU friendly) | |||
# --------------------------------------------------------------------- | |||
def wl_colours_gpu(triples: Tensor, n_ent: int, depth: int, device: torch.device) -> List[Tensor]: | |||
""" | |||
GPU implementation of classic Weisfeiler‑Lehman refinement. | |||
1. Each node starts with colour 0. | |||
2. At iteration h we hash the multiset of | |||
⬇︎ (r, colour(t)) and ⬆︎ (r, colour(h)) | |||
signatures into new colour ids. | |||
3. Hashing is done with a fast 64‑bit mix; collisions are unlikely and | |||
immaterial for CREC (only *relative* colours matter). | |||
""" | |||
h, r, t = triples.unbind(dim=1) # (|T|,) | |||
colours = [torch.zeros(n_ent, dtype=torch.long, device=device)] # C⁰ = 0 | |||
# pre‑compute 64‑bit mix constants once | |||
mix_r = ( | |||
torch.arange(triples[:, 1].max() + 1, device=device, dtype=torch.long) | |||
* 1_146_189_683_093_321_123 | |||
) | |||
mix_c = 8_636_673_225_737_527_201 | |||
for _ in range(depth): | |||
prev = colours[-1] # (|V|,) | |||
# signatures for outgoing edges | |||
sig_out = mix_r[r] ^ (prev[t] * mix_c) | |||
        # signatures for incoming edges (64-bit golden-ratio constant written as a
        # signed literal so it binds as int64)
        sig_in = mix_r[r] ^ (prev[h] * mix_c) ^ -0x61C8864680B583EB
# bucket signatures back to the source/target nodes | |||
col_out = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, h, sig_out) | |||
col_in = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, t, sig_in) | |||
# combine ↓ and ↑ multiset hashes with current colour | |||
raw = (prev * 3_205_813_371) ^ col_out ^ (col_in << 1) | |||
# re‑map to dense consecutive ids with torch.unique | |||
uniq, new = torch.unique(raw, sorted=True, return_inverse=True) | |||
colours.append(new) | |||
return colours # length = depth+1, each (|V|,) long | |||
# --------------------------------------------------------------------- | |||
# WL‑CREC metric | |||
# --------------------------------------------------------------------- | |||
def wl_crec_gpu( | |||
triples: Tensor, colours: List[Tensor], n_ent: int, n_rel: int, depth: int, device: torch.device | |||
) -> float: | |||
log_n = log(max(n_ent, 2)) | |||
# ------- pattern‑diversity C ------------------------------------- | |||
C = 0.0 | |||
for c in colours: # (|V|,) | |||
# histogram via bincount on GPU | |||
hist = torch.bincount(c).float() | |||
p = hist / n_ent | |||
C += -(p * torch.log(p.clamp_min(1e-30))).sum().item() | |||
C /= (depth + 1) * log_n | |||
# ------- residual entropy H_c ------------------------------------ | |||
h, r, t = triples.unbind(dim=1) | |||
sigma, tau = colours[depth][h], colours[depth][t] # (|T|,) | |||
# encode (sigma,tau) pairs into a single 64‑bit key | |||
sig_keys = (sigma.to(torch.int64) << 32) + tau.to(torch.int64) | |||
# total count per signature | |||
sig_unique, sig_inv, sig_counts = torch.unique( | |||
sig_keys, return_inverse=True, return_counts=True, sorted=False | |||
) | |||
    # build 2-D contingency table counts[(sigma,tau), r]
    m = triples.size(0)
    keys_2d = sig_inv * n_rel + r  # flat index into the (|signatures|, n_rel) table
    rc_dense = torch.zeros(len(sig_unique), n_rel, device=device, dtype=torch.long)
    # accumulate counts on the flattened view; a scatter_add_ with a 1-D index
    # into the 2-D table would mis-index the rows
    rc_dense.view(-1).index_add_(0, keys_2d, torch.ones(m, device=device, dtype=torch.long))
# conditional entropy per (sigma,tau) | |||
p_r = rc_dense.float() / sig_counts.unsqueeze(1).float() | |||
inner = -(p_r * torch.log(p_r.clamp_min(1e-30))).sum(dim=1) # (|sigma|,) | |||
Hc = (sig_counts.float() / m * inner).sum().item() / log(max(n_rel, 2)) | |||
return C * Hc | |||
# --------------------------------------------------------------------- | |||
# delta‑entropy helpers (still CPU side for clarity; cost is negligible) | |||
# --------------------------------------------------------------------- | |||
# The deterministic delta‑formulas depend on small dictionaries and are evaluated | |||
# on ≤256 candidate edges each iteration – copy them from the original script. | |||
inner_entropy = ... # unchanged | |||
delta_h_cond_remove = ... # unchanged | |||
delta_h_cond_add = ... # unchanged | |||
# --------------------------------------------------------------------- | |||
# Greedy delta‑search driver | |||
# --------------------------------------------------------------------- | |||
def greedy_delta_tune_gpu( | |||
triples: List[Triple], | |||
n_ent: int, | |||
n_rel: int, | |||
depth: int, | |||
lower: float, | |||
upper: float, | |||
device: torch.device, | |||
max_iters: int = 40_000, | |||
sample_size: int = 256, | |||
seed: int = 0, | |||
) -> List[Triple]: | |||
# book‑keeping in Python lists (cheap); heavy math in torch on GPU | |||
rng = random.Random(seed) | |||
# cache torch version of triples for fast metric eval | |||
def triples_to_tensor(ts: List[Triple]) -> Tensor: | |||
return torch.tensor(ts, dtype=torch.long, device=device, requires_grad=False) | |||
for it in range(max_iters): | |||
tri_tensor = triples_to_tensor(triples) | |||
colours = wl_colours_gpu(tri_tensor, n_ent, depth, device) | |||
crec = wl_crec_gpu(tri_tensor, colours, n_ent, n_rel, depth, device) | |||
if lower <= crec <= upper: | |||
logging.info("WL‑CREC %.4f reached after %d edits (|T|=%d)", crec, it, len(triples)) | |||
return triples | |||
# ---------------------------------------------------------------- | |||
# build signature statistics (CPU, cheap: O(|T|)) | |||
# ---------------------------------------------------------------- | |||
depth_col = colours[depth].cpu() | |||
sig_cnt: Dict[Tuple[int, int], int] = {} | |||
rel_cnt: Dict[Tuple[int, int, int], int] = {} | |||
        for h, r, t in triples:
sigma, tau = int(depth_col[h]), int(depth_col[t]) | |||
sig_cnt[(sigma, tau)] = sig_cnt.get((sigma, tau), 0) + 1 | |||
rel_cnt[(sigma, tau, r)] = rel_cnt.get((sigma, tau, r), 0) + 1 | |||
det_edges, div_edges = [], [] | |||
for idx, (h, r, t) in enumerate(triples): | |||
sigma, tau = int(depth_col[h]), int(depth_col[t]) | |||
if sum(rel_cnt.get((sigma, tau, rr), 0) > 0 for rr in range(n_rel)) == 1: | |||
det_edges.append(idx) | |||
if sigma != tau: | |||
div_edges.append(idx) | |||
# ---------------------------------------------------------------- | |||
# candidate generation + best edit selection | |||
# ---------------------------------------------------------------- | |||
total = len(triples) | |||
target_high = crec < lower # need to raise WL‑CREC? | |||
candidates = [] | |||
if target_high and det_edges: | |||
rng.shuffle(det_edges) | |||
for idx in det_edges[:sample_size]: | |||
h, r, t = triples[idx] | |||
sig = (int(depth_col[h]), int(depth_col[t])) | |||
delta = delta_h_cond_remove(sig, r, sig_cnt, rel_cnt, total, crec, n_rel) | |||
if delta > 0: | |||
candidates.append(("remove", idx, delta)) | |||
elif not target_high and det_edges: | |||
rng.shuffle(det_edges) | |||
for idx in det_edges[:sample_size]: | |||
h, r, t = triples[idx] | |||
sig = (int(depth_col[h]), int(depth_col[t])) | |||
delta = delta_h_cond_add(sig, r, sig_cnt, rel_cnt, total, crec, n_rel) | |||
if delta < 0: | |||
candidates.append(("add", (h, r, t), delta)) | |||
# fall‑back heuristics | |||
if not candidates: | |||
if target_high and div_edges: | |||
idx = rng.choice(div_edges) | |||
candidates.append(("remove", idx, 1e-9)) | |||
elif not target_high: | |||
idx = rng.choice(det_edges) if det_edges else rng.randrange(total) | |||
h, r, t = triples[idx] | |||
candidates.append(("add", (h, r, t), -1e-9)) | |||
best = ( | |||
max(candidates, key=lambda x: x[2]) | |||
if target_high | |||
else min(candidates, key=lambda x: x[2]) | |||
) | |||
# apply edit | |||
if best[0] == "remove": | |||
triples.pop(best[1]) | |||
else: | |||
triples.append(best[1]) | |||
if (it + 1) % 1_000 == 0: | |||
logging.info("[iter %d] WL‑CREC %.4f |T|=%d", it + 1, crec, len(triples)) | |||
raise RuntimeError("Max iterations exceeded without hitting WL‑CREC band.") |
@@ -0,0 +1,30 @@ | |||
import torch | |||
from models.base_model import ERModel | |||
def _rank(scores: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |||
return (scores > scores[target]).sum() + 1 | |||
def simple_ranking(model: ERModel, triples: torch.Tensor) -> torch.Tensor:
    """Rank the true tail of each triple against all candidate tails (unfiltered)."""
    h, r, t = triples[:, 0], triples[:, 1], triples[:, 2]
    E = model.num_entities
    ranks = []
    with torch.no_grad():
        for hi, ri, ti in zip(h, r, t):
            # Score (hi, ri, ?) against every candidate tail entity.
            tails = torch.arange(E, device=triples.device)
            candidates = torch.stack([hi.repeat(E), ri.repeat(E), tails], dim=1)
            scores = model(candidates, mode=None)
            ranks.append(_rank(scores, ti))
    return torch.stack(ranks).to(device=triples.device, dtype=torch.float32)
def mrr(r: torch.Tensor) -> torch.Tensor: | |||
return (1 / r).mean() | |||
def hits_at_k(r: torch.Tensor, k: int = 10) -> torch.Tensor: | |||
return (r <= k).float().mean() |
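# --------------------------------------------------------------------
# Example usage (illustrative sketch): the rank-based metrics work on
# any 1-D tensor of ranks, so a hand-written tensor is enough here;
# `simple_ranking` itself requires a trained ERModel.
# --------------------------------------------------------------------
if __name__ == "__main__":
    ranks = torch.tensor([1.0, 2.0, 10.0, 50.0])
    print(f"MRR     : {mrr(ranks):.4f}")            # (1 + 1/2 + 1/10 + 1/50) / 4 = 0.405
    print(f"Hits@10 : {hits_at_k(ranks, k=10):.2f}")  # 3 of 4 ranks are <= 10 -> 0.75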
@@ -0,0 +1,237 @@ | |||
from __future__ import annotations | |||
import math | |||
from collections import Counter, defaultdict | |||
from typing import TYPE_CHECKING, Dict, List, Tuple | |||
import torch | |||
from .base_metric import BaseMetric | |||
if TYPE_CHECKING: | |||
from data.kg_dataset import KGDataset | |||
def _entropy_from_counter(counter: Counter[int], n: int) -> float: | |||
"""Return Shannon entropy (nats) of a colour distribution of length *n*.""" | |||
ent = 0.0 | |||
for cnt in counter.values(): | |||
if cnt: | |||
p = cnt / n | |||
ent -= p * math.log(p) | |||
return ent | |||
def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]: | |||
"""Return (C_ratio, C_NWLEC) ∈ [0,1]² from per-layer entropies.""" | |||
if n <= 1: | |||
return 0.0, 0.0 | |||
H_plus_1 = len(entropies) | |||
log_n = math.log(n) | |||
c_ratio = sum(ent / log_n for ent in entropies) / H_plus_1 | |||
c_nwlec = sum((math.exp(ent) - 1) / (n - 1) for ent in entropies) / H_plus_1 | |||
return c_ratio, c_nwlec | |||
def _relation_families( | |||
triples: torch.Tensor, num_relations: int, *, thresh: float = 0.9 | |||
) -> torch.Tensor: | |||
"""Return a LongTensor mapping each relation id → family id. | |||
Two relations *r, r_inv* are put into the same family when **at least | |||
`thresh` fraction** of the edges labelled *r* have a **single** reverse edge | |||
labelled *r_inv*. | |||
""" | |||
assert 0.0 < thresh <= 1.0 | |||
# Build map (h,t) -> list of relations on that edge direction. | |||
edge2rels: Dict[Tuple[int, int], List[int]] = defaultdict(list) | |||
for h, r, t in triples.tolist(): | |||
edge2rels[(h, t)].append(int(r)) | |||
# Count reciprocal co‑occurrences (r, r_rev). | |||
pair_counts: Dict[Tuple[int, int], int] = defaultdict(int) | |||
rel_totals: List[int] = [0] * num_relations | |||
for (h, t), rels in edge2rels.items(): | |||
rev_rels = edge2rels.get((t, h)) | |||
if not rev_rels: | |||
continue | |||
for r in rels: | |||
rel_totals[r] += 1 | |||
for r_rev in rev_rels: | |||
pair_counts[(r, r_rev)] += 1 | |||
# Proposed inverse mapping r -> r_inv. | |||
proposed: List[int | None] = [None] * num_relations | |||
for (r, r_rev), cnt in pair_counts.items(): | |||
if cnt == 0: | |||
continue | |||
if cnt / rel_totals[r] >= thresh: | |||
# majority of r's edges are reversed by r_rev | |||
proposed[r] = r_rev | |||
# Build family ids (transitive closure is overkill; data are simple). | |||
family = list(range(num_relations)) | |||
for r, rinv in enumerate(proposed): | |||
if rinv is not None: | |||
root = min(family[r], family[rinv]) | |||
family[r] = family[rinv] = root | |||
return torch.tensor(family, dtype=torch.long) | |||
def _conditional_relation_entropy( | |||
colours: List[int], | |||
triples: torch.Tensor, | |||
family_map: torch.Tensor, | |||
) -> float: | |||
counts: Dict[Tuple[int, int], Dict[int, int]] = defaultdict(lambda: defaultdict(int)) | |||
m = int(triples.size(0)) | |||
fam = family_map # alias for speed | |||
for h, r, t in triples.tolist(): | |||
# key = (colours[h], colours[t]) # ordered key (direction‑aware) | |||
key = tuple(sorted((colours[h], colours[t]))) # unordered key (direction‑agnostic) | |||
counts[key][int(fam[r])] += 1 # ▼ use family id | |||
h_cond = 0.0 | |||
for rel_counts in counts.values(): | |||
total_s = sum(rel_counts.values()) | |||
P_s = total_s / m | |||
inv_total = 1.0 / total_s | |||
for cnt in rel_counts.values(): | |||
p_rs = cnt * inv_total | |||
h_cond -= P_s * p_rs * math.log(p_rs) | |||
return h_cond | |||
class WLCREC(BaseMetric): | |||
"""Class to compute the Weisfeiler-Lehman Entropy Complexity (WLEC) for a knowledge graph.""" | |||
def __init__(self, dataset: KGDataset) -> None: | |||
super().__init__(dataset) | |||
def compute( | |||
self, | |||
H: int = 3, | |||
cond_h: int = 1, | |||
inv_thresh: float = 0.9, | |||
return_full: bool = False, | |||
progress: bool = True, | |||
    ) -> Tuple[float, ...]:
        """Compute WL-entropy scores and the composite difficulty metric.
        Parameters
        ----------
        H
            WL refinement depth (*≥0*). Entropy is recorded for **H+1** layers.
        cond_h
            Index of the WL layer whose colours condition the residual relation entropy.
        inv_thresh
            Fraction of reciprocated edges needed to merge a relation and its inverse
            into one family.
        return_full
            Also return the list of per-layer entropies.
        progress
            Print textual progress bar.
        Returns
        -------
        avg_entropy, C_ratio, C_NWLEC, H_cond, D_ratio, D_NWLEC
            *Six* scalars (floats). If ``return_full=True`` an extra list of
            layer entropies is appended.
        """
# compute relation family map once | |||
family_map = _relation_families( | |||
self.dataset.triples, self.dataset.num_relations, thresh=inv_thresh | |||
) | |||
# ─── WL iterations ──────────────────────────────────────────────────── | |||
n = self.dataset.num_entities | |||
colour: List[int] = [1] * n # colour 1 for everyone | |||
entropies: List[float] = [] | |||
colours_per_h: List[List[int]] = [] | |||
def _print_prog(i: int) -> None: | |||
if progress: | |||
print(f"\rWL iteration {i}/{H}", end="", flush=True) | |||
for h in range(H + 1): | |||
_print_prog(h) | |||
colours_per_h.append(colour.copy()) | |||
# entropy of current colouring | |||
freq = Counter(colour) | |||
entropies.append(_entropy_from_counter(freq, n)) | |||
if h == H: | |||
break | |||
# refine colours | |||
bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {} | |||
next_colour: List[int] = [0] * n | |||
for v in range(n): | |||
T: List[Tuple[int, int, int]] = [] | |||
for r, t in self.out_adj[v]: | |||
T.append((r, 0, colour[t])) # outgoing | |||
for r, h_ in self.in_adj[v]: | |||
T.append((r, 1, colour[h_])) # incoming | |||
T.sort() | |||
key = (colour[v], tuple(T)) | |||
if key not in bucket: | |||
bucket[key] = len(bucket) + 1 | |||
next_colour[v] = bucket[key] | |||
colour = next_colour | |||
_print_prog(H) | |||
if progress: | |||
print() | |||
# ─── Normalised diversity ──────────────────────────────────────────── | |||
avg_entropy = sum(entropies) / len(entropies) | |||
c_ratio, c_nwlec = _normalise_scores(entropies, n) | |||
# ─── Residual relation entropy ─────────────────────────────────────── | |||
h_cond = _conditional_relation_entropy( | |||
colours_per_h[cond_h], self.dataset.triples, family_map | |||
) | |||
# ─── Composite difficulty - Eq. (1) ────────────────────────────────── | |||
d_ratio = c_ratio * h_cond | |||
d_nwlec = c_nwlec * h_cond | |||
result = (avg_entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec) | |||
if return_full: | |||
result += (entropies,) | |||
return result | |||
# -------------------------------------------------------------------- | |||
# Weisfeiler–Lehman colours (unchanged) | |||
# -------------------------------------------------------------------- | |||
def wl_colours(self, depth: int) -> List[List[int]]: | |||
triples = self.dataset.triples.tolist() | |||
out_adj, in_adj = defaultdict(list), defaultdict(list) | |||
for h, r, t in triples: | |||
out_adj[h].append((r, t)) | |||
in_adj[t].append((r, h)) | |||
n_ent = self.dataset.num_entities | |||
colours = [[0] * n_ent] # round-0 | |||
for h in range(1, depth + 1): | |||
prev, nxt, sig2c = colours[-1], [0] * n_ent, {} | |||
fresh = 0 | |||
for v in range(n_ent): | |||
neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [ | |||
("↑", r, prev[u]) for r, u in in_adj.get(v, []) | |||
] | |||
neigh.sort() | |||
sig = (prev[v], tuple(neigh)) | |||
if sig not in sig2c: | |||
sig2c[sig] = fresh | |||
fresh += 1 | |||
nxt[v] = sig2c[sig] | |||
colours.append(nxt) | |||
return colours |
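# --------------------------------------------------------------------
# Example (illustrative sketch): relation-family detection on a toy
# graph in which relation 1 acts as the inverse of relation 0, so both
# end up in the same family and the conditional relation entropy does
# not count the inverse duplicates. Run as a module (e.g.
# `python -m metrics.wlcrec` from the project root) so the relative
# import of BaseMetric resolves.
# --------------------------------------------------------------------
if __name__ == "__main__":
    toy = torch.tensor([[0, 0, 1], [1, 1, 0], [2, 0, 3], [3, 1, 2]], dtype=torch.long)
    fam = _relation_families(toy, num_relations=2)
    print(f"relation -> family map: {fam.tolist()}")  # expected [0, 0]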
@@ -0,0 +1,135 @@ | |||
from __future__ import annotations | |||
import math | |||
from collections import Counter | |||
from typing import TYPE_CHECKING, Dict, List, Tuple | |||
from .base_metric import BaseMetric | |||
if TYPE_CHECKING: | |||
from data.kg_dataset import KGDataset | |||
def _entropy_from_counter(counter: Counter[int], n: int) -> float: | |||
"""Shannon entropy ``H = -Σ pᵢ log pᵢ`` in *nats* for the given colour counts.""" | |||
ent = 0.0 | |||
for count in counter.values(): | |||
if count: # avoid log(0) | |||
p = count / n | |||
ent -= p * math.log(p) | |||
return ent | |||
def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]: | |||
"""Return (C_ratio, C_NWLEC) given *per-iteration* entropies.""" | |||
if n <= 1: | |||
# Degenerate graph. | |||
return 0.0, 0.0 | |||
H_plus_1 = len(entropies) | |||
# Eq. (1): simple ratio normalisation. | |||
c_ratio = sum(ent / math.log(n) for ent in entropies) / H_plus_1 | |||
# Eq. (2): effective-colour NWLEC. | |||
k_terms = [(math.exp(ent) - 1) / (n - 1) for ent in entropies] | |||
c_nwlec = sum(k_terms) / H_plus_1 | |||
return c_ratio, c_nwlec | |||
class WLEC(BaseMetric): | |||
"""Class to compute the Weisfeiler-Lehman Entropy Complexity (WLEC) for a knowledge graph.""" | |||
def __init__(self, dataset: KGDataset) -> None: | |||
super().__init__(dataset) | |||
def compute( | |||
self, | |||
H: int = 3, | |||
*, | |||
return_full: bool = False, | |||
progress: bool = True, | |||
    ) -> Tuple[float, float, float] | Tuple[float, List[float], float, float]:
        """Compute the *average* Weisfeiler-Lehman Entropy Complexity (WLEC).
        Parameters
        ----------
        H:
            Number of refinement iterations **after** the initial colouring (so the
            colours are updated **H** times and the entropy is measured **H+1** times).
        return_full:
            If *True*, additionally return the list of entropies for each depth.
        progress:
            Whether to print a mini progress bar (no external dependency).
        Returns
        -------
        (average_entropy, c_ratio, c_nwlec) or (average_entropy, entropies, c_ratio, c_nwlec)
            Average of the per-layer entropies and the two normalised scores; with
            ``return_full=True`` the list of layer entropies is inserted after the average.
        """
n = int(self.num_entities) | |||
if n == 0: | |||
msg = "Dataset appears to contain zero entities." | |||
raise ValueError(msg) | |||
# Colour assignment - we keep it as a *list* of ints for cheap hashing. | |||
# Start with colour 1 for every entity (index 0 unused). | |||
colour: List[int] = [1] * n | |||
entropies: List[float] = [] | |||
# Optional poor-man's progress bar. | |||
def _print_prog(i: int) -> None: | |||
if progress: | |||
print(f"\rIteration {i}/{H}", end="", flush=True) | |||
for h in range(H + 1): | |||
_print_prog(h) | |||
# ------------------------------------------------------------------ | |||
# Step 1: measure entropy of current colouring ---------------------- | |||
# ------------------------------------------------------------------ | |||
freq = Counter(colour) | |||
ent = _entropy_from_counter(freq, n) | |||
entropies.append(ent) | |||
if h == H: | |||
break # We have reached the requested depth - stop here. | |||
# ------------------------------------------------------------------ | |||
# Step 2: create refined colours ----------------------------------- | |||
# ------------------------------------------------------------------ | |||
bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {} | |||
next_colour: List[int] = [0] * n | |||
for v in range(n): | |||
# Build the multiset T containing outgoing and incoming features. | |||
T: List[Tuple[int, int, int]] = [] | |||
for r, t in self.out_adj[v]: | |||
T.append((r, 0, colour[t])) # 0 = outgoing (\u2193) | |||
for r, h_ in self.in_adj[v]: | |||
T.append((r, 1, colour[h_])) # 1 = incoming (\u2191) | |||
T.sort() # canonical ordering | |||
key = (colour[v], tuple(T)) | |||
if key not in bucket: | |||
bucket[key] = len(bucket) + 1 # new colour ID (start at 1) | |||
next_colour[v] = bucket[key] | |||
colour = next_colour # move to next iteration | |||
_print_prog(H) | |||
if progress: | |||
print() # newline after progress bar | |||
average_entropy = sum(entropies) / len(entropies) | |||
c_ratio, c_nwlec = _normalise_scores(entropies, n) | |||
if return_full: | |||
return average_entropy, entropies, c_ratio, c_nwlec | |||
return average_entropy, c_ratio, c_nwlec |
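# --------------------------------------------------------------------
# Example (illustrative sketch): the entropy and normalisation helpers
# on a hand-made colour distribution over n = 4 entities. Run as a
# module so the relative import of BaseMetric resolves.
# --------------------------------------------------------------------
if __name__ == "__main__":
    colour_counts = Counter({1: 2, 2: 1, 3: 1})      # colours of 4 entities
    ent = _entropy_from_counter(colour_counts, 4)    # -(1/2 ln 1/2 + 2 * 1/4 ln 1/4) ≈ 1.0397
    c_ratio, c_nwlec = _normalise_scores([ent], 4)
    print(f"H = {ent:.4f} nats, C_ratio = {c_ratio:.4f}, C_NWLEC = {c_nwlec:.4f}")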
@@ -0,0 +1,127 @@ | |||
from __future__ import annotations | |||
from abc import ABC, abstractmethod | |||
from typing import TYPE_CHECKING | |||
import torch | |||
from torch import nn | |||
from torch.nn.functional import log_softmax, softmax | |||
if TYPE_CHECKING: | |||
from pykeen.typing import FloatTensor, InductiveMode, LongTensor, MappedTriples, Target | |||
class ERModel(nn.Module, ABC): | |||
"""Base class for knowledge graph models.""" | |||
def __init__(self, num_entities: int, num_relations: int, dim: int) -> None: | |||
super().__init__() | |||
self.num_entities = num_entities | |||
self.num_relations = num_relations | |||
self.dim = dim | |||
self.inner_model = None | |||
@property | |||
def device(self) -> torch.device: | |||
"""Return the device of the model.""" | |||
return next(self.inner_model.parameters()).device | |||
@abstractmethod | |||
def reset_parameters(self, *args, **kwargs) -> None: | |||
"""Reset the parameters of the model.""" | |||
def predict( | |||
self, | |||
hrt_batch: MappedTriples, | |||
target: Target, | |||
full_batch: bool = True, | |||
ids: LongTensor | None = None, | |||
**kwargs, | |||
) -> FloatTensor: | |||
if not self.inner_model: | |||
msg = ( | |||
"Inner model is not set. Please initialize the inner model before calling predict." | |||
) | |||
raise ValueError( | |||
msg, | |||
) | |||
return self.inner_model.predict( | |||
hrt_batch=hrt_batch, | |||
target=target, | |||
full_batch=full_batch, | |||
ids=ids, | |||
**kwargs, | |||
) | |||
def forward( | |||
self, | |||
triples: MappedTriples, | |||
slice_size: int | None = None, | |||
slice_dim: int = 0, | |||
*, | |||
mode: InductiveMode | None, | |||
) -> FloatTensor: | |||
h_indices = triples[:, 0] | |||
r_indices = triples[:, 1] | |||
t_indices = triples[:, 2] | |||
if not self.inner_model: | |||
msg = ( | |||
"Inner model is not set. Please initialize the inner model before calling forward." | |||
) | |||
raise ValueError( | |||
msg, | |||
) | |||
return self.inner_model.forward( | |||
h_indices=h_indices, | |||
r_indices=r_indices, | |||
t_indices=t_indices, | |||
slice_size=slice_size, | |||
slice_dim=slice_dim, | |||
mode=mode, | |||
) | |||
@torch.inference_mode() | |||
def tail_distribution( # (0) TAIL-given-(head, relation) → pθ(t | h,r) | |||
self, | |||
hr_batch: LongTensor, # (B, 2) : (h,r) | |||
*, | |||
slice_size: int | None = None, | |||
mode: InductiveMode | None = None, | |||
log: bool = False, | |||
) -> FloatTensor: # (B, |E|) | |||
scores = self.inner_model.score_t(hr_batch=hr_batch, slice_size=slice_size, mode=mode) | |||
return log_softmax(scores, -1) if log else softmax(scores, -1) | |||
# ----------------------------------------------------------------------- # | |||
# (1) HEAD-given-(relation, tail) → pθ(h | r,t) # | |||
# ----------------------------------------------------------------------- # | |||
@torch.inference_mode() | |||
def head_distribution( | |||
self, | |||
rt_batch: LongTensor, # (B, 2) : (r,t) | |||
*, | |||
slice_size: int | None = None, | |||
mode: InductiveMode | None = None, | |||
log: bool = False, | |||
) -> FloatTensor: # (B, |E|) | |||
scores = self.inner_model.score_h(rt_batch=rt_batch, slice_size=slice_size, mode=mode) | |||
return log_softmax(scores, -1) if log else softmax(scores, -1) | |||
# ----------------------------------------------------------------------- # | |||
# (2) RELATION-given-(head, tail) → pθ(r | h,t) # | |||
# ----------------------------------------------------------------------- # | |||
@torch.inference_mode() | |||
def relation_distribution( | |||
self, | |||
ht_batch: LongTensor, # (B, 2) : (h,t) | |||
*, | |||
slice_size: int | None = None, | |||
mode: InductiveMode | None = None, | |||
log: bool = False, | |||
) -> FloatTensor: # (B, |R|) | |||
scores = self.inner_model.score_r(ht_batch=ht_batch, slice_size=slice_size, mode=mode) | |||
return log_softmax(scores, -1) if log else softmax(scores, -1) |
@@ -0,0 +1,70 @@ | |||
from __future__ import annotations | |||
from typing import TYPE_CHECKING | |||
import torch | |||
from pykeen.models import TransE as PyKEENTransE | |||
from torch import nn | |||
from models.base_model import ERModel | |||
if TYPE_CHECKING: | |||
from pykeen.typing import MappedTriples | |||
class TransE(ERModel): | |||
""" | |||
TransE model for knowledge graph embedding. | |||
score = ||h + r - t|| | |||
""" | |||
def __init__( | |||
self, | |||
num_entities: int, | |||
num_relations: int, | |||
triples_factory: MappedTriples, | |||
dim: int = 200, | |||
sf_norm: int = 1, | |||
p_norm: bool = True, | |||
margin: float | None = None, | |||
epsilon: float | None = None, | |||
device: torch.device = torch.device("cpu"), | |||
) -> None: | |||
super().__init__(num_entities, num_relations, dim) | |||
self.dim = dim | |||
self.sf_norm = sf_norm | |||
self.p_norm = p_norm | |||
self.margin = margin | |||
self.epsilon = epsilon | |||
self.inner_model = PyKEENTransE( | |||
triples_factory=triples_factory, | |||
embedding_dim=dim, | |||
scoring_fct_norm=sf_norm, | |||
power_norm=p_norm, | |||
random_seed=42, | |||
).to(device) | |||
def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None: | |||
# Parameter initialization | |||
if margin is None or epsilon is None: | |||
# If no margin/epsilon are provided, use Xavier uniform initialization | |||
nn.init.xavier_uniform_(self.ent_embeddings.weight.data) | |||
nn.init.xavier_uniform_(self.rel_embeddings.weight.data) | |||
else: | |||
# Otherwise, initialize uniformly in [ -(margin+epsilon)/dim , +(margin+epsilon)/dim ] | |||
self.embedding_range = nn.Parameter( | |||
torch.Tensor([(margin + epsilon) / self.dim]), | |||
requires_grad=False, | |||
) | |||
nn.init.uniform_( | |||
tensor=self.ent_embeddings.weight.data, | |||
a=-self.embedding_range.item(), | |||
b=self.embedding_range.item(), | |||
) | |||
nn.init.uniform_( | |||
tensor=self.rel_embeddings.weight.data, | |||
a=-self.embedding_range.item(), | |||
b=self.embedding_range.item(), | |||
) |
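# --------------------------------------------------------------------
# Example usage (illustrative sketch, assuming PyKEEN's TriplesFactory
# API): build a tiny factory from labelled string triples and wrap it
# in this TransE class. The entity/relation names below are made up.
# --------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np
    from pykeen.triples import TriplesFactory

    labelled = np.array(
        [
            ("anna", "knows", "bob"),
            ("bob", "knows", "carol"),
            ("carol", "lives_in", "berlin"),
        ],
        dtype=str,
    )
    tf = TriplesFactory.from_labeled_triples(labelled)
    model = TransE(
        num_entities=tf.num_entities,
        num_relations=tf.num_relations,
        triples_factory=tf,
        dim=16,
    )
    # Score the three training triples with the wrapped PyKEEN model.
    print(model.inner_model.score_hrt(tf.mapped_triples).shape)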
@@ -0,0 +1,76 @@ | |||
from __future__ import annotations | |||
from typing import TYPE_CHECKING | |||
import torch | |||
from pykeen.losses import MarginRankingLoss | |||
from pykeen.models import TransH as PyKEENTransH | |||
from pykeen.regularizers import LpRegularizer | |||
from torch import nn | |||
from models.base_model import ERModel | |||
if TYPE_CHECKING: | |||
from pykeen.typing import MappedTriples | |||
class TransH(ERModel): | |||
""" | |||
TransH model for knowledge graph embedding. | |||
""" | |||
def __init__( | |||
self, | |||
num_entities: int, | |||
num_relations: int, | |||
triples_factory: MappedTriples, | |||
dim: int = 200, | |||
p_norm: bool = True, | |||
margin: float | None = None, | |||
epsilon: float | None = None, | |||
device: torch.device = torch.device("cpu"), | |||
**kwargs, | |||
) -> None: | |||
super().__init__(num_entities, num_relations, dim) | |||
self.dim = dim | |||
self.p_norm = p_norm | |||
self.margin = margin | |||
self.epsilon = epsilon | |||
self.inner_model = PyKEENTransH( | |||
triples_factory=triples_factory, | |||
embedding_dim=dim, | |||
scoring_fct_norm=1, | |||
loss=self.loss, | |||
power_norm=p_norm, | |||
entity_initializer="xavier_uniform_", # good default | |||
relation_initializer="xavier_uniform_", | |||
random_seed=42, | |||
).to(device) | |||
@property | |||
    def loss(self) -> MarginRankingLoss:
return MarginRankingLoss(margin=self.margin, reduction="mean") | |||
def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None: | |||
# Parameter initialization | |||
if margin is None or epsilon is None: | |||
# If no margin/epsilon are provided, use Xavier uniform initialization | |||
nn.init.xavier_uniform_(self.ent_embeddings.weight.data) | |||
nn.init.xavier_uniform_(self.rel_embeddings.weight.data) | |||
else: | |||
# Otherwise, initialize uniformly in [ -(margin+epsilon)/dim , +(margin+epsilon)/dim ] | |||
self.embedding_range = nn.Parameter( | |||
torch.Tensor([(margin + epsilon) / self.dim]), | |||
requires_grad=False, | |||
) | |||
nn.init.uniform_( | |||
tensor=self.ent_embeddings.weight.data, | |||
a=-self.embedding_range.item(), | |||
b=self.embedding_range.item(), | |||
) | |||
nn.init.uniform_( | |||
tensor=self.rel_embeddings.weight.data, | |||
a=-self.embedding_range.item(), | |||
b=self.embedding_range.item(), | |||
) |
@@ -0,0 +1,79 @@ | |||
from __future__ import annotations | |||
from typing import TYPE_CHECKING | |||
import torch | |||
from pykeen.losses import MarginRankingLoss | |||
from pykeen.models import TransR as PyKEENTransR | |||
from torch import nn | |||
from models.base_model import ERModel | |||
if TYPE_CHECKING: | |||
from pykeen.typing import MappedTriples | |||
class TransR(ERModel): | |||
""" | |||
TransR model for knowledge graph embedding. | |||
""" | |||
def __init__( | |||
self, | |||
num_entities: int, | |||
num_relations: int, | |||
triples_factory: MappedTriples, | |||
entity_dim: int = 200, | |||
relation_dim: int = 30, | |||
p_norm: bool = True, | |||
p_norm_value: int = 2, | |||
margin: float | None = None, | |||
epsilon: float | None = None, | |||
device: torch.device = torch.device("cpu"), | |||
**kwargs, | |||
) -> None: | |||
super().__init__(num_entities, num_relations, entity_dim) | |||
self.entity_dim = entity_dim | |||
self.relation_dim = relation_dim | |||
self.p_norm = p_norm | |||
self.margin = margin | |||
self.epsilon = epsilon | |||
self.inner_model = PyKEENTransR( | |||
triples_factory=triples_factory, | |||
embedding_dim=entity_dim, | |||
relation_dim=relation_dim, | |||
scoring_fct_norm=1, | |||
loss=self.loss, | |||
power_norm=p_norm, | |||
entity_initializer="xavier_uniform_", | |||
relation_initializer="xavier_uniform_", | |||
random_seed=42, | |||
).to(device) | |||
@property | |||
    def loss(self) -> MarginRankingLoss:
return MarginRankingLoss(margin=self.margin, reduction="mean") | |||
def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None: | |||
# Parameter initialization | |||
if margin is None or epsilon is None: | |||
# If no margin/epsilon are provided, use Xavier uniform initialization | |||
nn.init.xavier_uniform_(self.ent_embeddings.weight.data) | |||
nn.init.xavier_uniform_(self.rel_embeddings.weight.data) | |||
else: | |||
# Otherwise, initialize uniformly in [ -(margin+epsilon)/dim , +(margin+epsilon)/dim ] | |||
self.embedding_range = nn.Parameter( | |||
torch.Tensor([(margin + epsilon) / self.dim]), | |||
requires_grad=False, | |||
) | |||
nn.init.uniform_( | |||
tensor=self.ent_embeddings.weight.data, | |||
a=-self.embedding_range.item(), | |||
b=self.embedding_range.item(), | |||
) | |||
nn.init.uniform_( | |||
tensor=self.rel_embeddings.weight.data, | |||
a=-self.embedding_range.item(), | |||
b=self.embedding_range.item(), | |||
) |
@@ -0,0 +1,82 @@ | |||
[tool.ruff] | |||
# Exclude a variety of commonly ignored directories. | |||
exclude = [ | |||
".bzr", | |||
".direnv", | |||
".eggs", | |||
".git", | |||
".git-rewrite", | |||
".hg", | |||
".ipynb_checkpoints", | |||
".mypy_cache", | |||
".nox", | |||
".pants.d", | |||
".pyenv", | |||
".pytest_cache", | |||
".pytype", | |||
".ruff_cache", | |||
".svn", | |||
".tox", | |||
".venv", | |||
".vscode", | |||
"__pypackages__", | |||
"_build", | |||
"buck-out", | |||
"build", | |||
"dist", | |||
"node_modules", | |||
"site-packages", | |||
"venv", | |||
] | |||
respect-gitignore = true | |||
# Black-compatible formatting, but with a 100-character line length.
line-length = 100 | |||
indent-width = 4 | |||
[tool.ruff.lint]
# Enable all rule families, then opt out of specific rules via `ignore`.
select = ["ALL"]
ignore = ["ANN002", "ANN003", "RET504", "ERA001", "E741", "N812", "D", "PLR0913", "T201", "FBT001", "FBT002", "PLW0602", "PLW0603", "PTH", "C901", "PLR2004", "UP006", "UP035", "G004", "N803", "N806", "ARG002"]
# Allow fixes for all enabled rules (when `--fix` is provided).
fixable = ["ALL"]
unfixable = []
# Allow unused variables when underscore-prefixed.
#dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E", "F", "I", "N"]
[tool.ruff.format] | |||
# Like Black, use double quotes for strings. | |||
quote-style = "double" | |||
# Like Black, indent with spaces, rather than tabs. | |||
indent-style = "space" | |||
# Like Black, respect magic trailing commas. | |||
skip-magic-trailing-comma = false | |||
# Like Black, automatically detect the appropriate line ending. | |||
line-ending = "auto" | |||
# Enable auto-formatting of code examples in docstrings. Markdown, | |||
# reStructuredText code/literal blocks and doctests are all supported. | |||
# | |||
# This is currently disabled by default, but it is planned for this | |||
# to be opt-out in the future. | |||
docstring-code-format = false | |||
# Set the line length limit used when formatting code snippets in | |||
# docstrings. | |||
# | |||
# This only has an effect when the `docstring-code-format` setting is | |||
# enabled. | |||
docstring-code-line-length = "dynamic" | |||
[tool.pyright] | |||
typeCheckingMode = "off" |
@@ -0,0 +1,28 @@ | |||
#!/usr/bin/env bash | |||
# install.sh – minimal Conda bootstrap with ~/.bashrc check | |||
set -e | |||
ENV=kg-env        # conda env name (passed via -n, overrides the name: field in the YAML)
YAML=kg_env.yaml  # environment spec file, assumed to be in the current dir
# ── 1. try to load an existing conda from ~/.bashrc ────────────────────── | |||
[ -f "$HOME/.bashrc" ] && source "$HOME/.bashrc" || true | |||
# ── 2. ensure conda exists ─────────────────────────────────────────────── | |||
if ! command -v conda &>/dev/null; then | |||
echo "[install.sh] Conda not found → installing Miniconda." | |||
curl -fsSL https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o miniconda.sh | |||
bash miniconda.sh -b -p "$HOME/miniconda3" | |||
rm miniconda.sh | |||
source "$HOME/miniconda3/etc/profile.d/conda.sh" | |||
conda init bash >/dev/null # so future shells have it | |||
else | |||
# conda exists; load its helper for this shell | |||
source "$(conda info --base)/etc/profile.d/conda.sh" | |||
fi | |||
# ── 3. create or update the project env ────────────────────────────────── | |||
conda env create -n "$ENV" -f "$YAML" \ | |||
|| conda env update -n "$ENV" -f "$YAML" --prune | |||
echo "✔ Done. Activate with: conda activate $ENV" |
@@ -0,0 +1,163 @@ | |||
# environment.yml ── KG-Embeddings project | |||
name: kg | |||
channels: | |||
- pytorch | |||
- defaults | |||
- nvidia | |||
- conda-forge | |||
- https://repo.anaconda.com/pkgs/main | |||
- https://repo.anaconda.com/pkgs/r | |||
dependencies: | |||
- _libgcc_mutex=0.1=conda_forge | |||
- _openmp_mutex=4.5=2_gnu | |||
- aom=3.6.1=h59595ed_0 | |||
- blas=1.0=mkl | |||
- brotli-python=1.1.0=py311hfdbb021_2 | |||
- bzip2=1.0.8=h4bc722e_7 | |||
- ca-certificates=2024.12.14=hbcca054_0 | |||
- certifi=2024.12.14=pyhd8ed1ab_0 | |||
- cffi=1.17.1=py311hf29c0ef_0 | |||
- charset-normalizer=3.4.1=pyhd8ed1ab_0 | |||
- cpython=3.11.11=py311hd8ed1ab_1 | |||
- cuda-cudart=12.4.127=0 | |||
- cuda-cupti=12.4.127=0 | |||
- cuda-libraries=12.4.1=0 | |||
- cuda-nvrtc=12.4.127=0 | |||
- cuda-nvtx=12.4.127=0 | |||
- cuda-opencl=12.6.77=0 | |||
- cuda-runtime=12.4.1=0 | |||
- cuda-version=12.6=3 | |||
- ffmpeg=4.4.2=gpl_hdf48244_113 | |||
- filelock=3.16.1=pyhd8ed1ab_1 | |||
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0 | |||
- font-ttf-inconsolata=3.000=h77eed37_0 | |||
- font-ttf-source-code-pro=2.038=h77eed37_0 | |||
- font-ttf-ubuntu=0.83=h77eed37_3 | |||
- fontconfig=2.15.0=h7e30c49_1 | |||
- fonts-conda-ecosystem=1=0 | |||
- fonts-conda-forge=1=0 | |||
- freetype=2.12.1=h267a509_2 | |||
- gettext=0.22.5=he02047a_3 | |||
- gettext-tools=0.22.5=he02047a_3 | |||
- giflib=5.2.2=hd590300_0 | |||
- gmp=6.3.0=hac33072_2 | |||
- gmpy2=2.1.5=py311h0f6cedb_3 | |||
- gnutls=3.7.9=hb077bed_0 | |||
- h2=4.1.0=pyhd8ed1ab_1 | |||
- hpack=4.0.0=pyhd8ed1ab_1 | |||
- hyperframe=6.0.1=pyhd8ed1ab_1 | |||
- idna=3.10=pyhd8ed1ab_1 | |||
- intel-openmp=2022.0.1=h06a4308_3633 | |||
- jinja2=3.1.5=pyhd8ed1ab_0 | |||
- lame=3.100=h166bdaf_1003 | |||
- lcms2=2.16=hb7c19ff_0 | |||
- ld_impl_linux-64=2.43=h712a8e2_2 | |||
- lerc=4.0.0=h27087fc_0 | |||
- libasprintf=0.22.5=he8f35ee_3 | |||
- libasprintf-devel=0.22.5=he8f35ee_3 | |||
- libblas=3.9.0=16_linux64_mkl | |||
- libcblas=3.9.0=16_linux64_mkl | |||
- libcublas=12.4.5.8=0 | |||
- libcufft=11.2.1.3=0 | |||
- libcufile=1.11.1.6=0 | |||
- libcurand=10.3.7.77=0 | |||
- libcusolver=11.6.1.9=0 | |||
- libcusparse=12.3.1.170=0 | |||
- libdeflate=1.23=h4ddbbb0_0 | |||
- libdrm=2.4.124=hb9d3cd8_0 | |||
- libegl=1.7.0=ha4b6fd6_2 | |||
- libexpat=2.6.4=h5888daf_0 | |||
- libffi=3.4.2=h7f98852_5 | |||
- libgcc=14.2.0=h77fa898_1 | |||
- libgcc-ng=14.2.0=h69a702a_1 | |||
- libgettextpo=0.22.5=he02047a_3 | |||
- libgettextpo-devel=0.22.5=he02047a_3 | |||
- libgl=1.7.0=ha4b6fd6_2 | |||
- libglvnd=1.7.0=ha4b6fd6_2 | |||
- libglx=1.7.0=ha4b6fd6_2 | |||
- libgomp=14.2.0=h77fa898_1 | |||
- libiconv=1.17=hd590300_2 | |||
- libidn2=2.3.7=hd590300_0 | |||
- libjpeg-turbo=3.0.0=hd590300_1 | |||
- liblapack=3.9.0=16_linux64_mkl | |||
- liblzma=5.6.3=hb9d3cd8_1 | |||
- libnpp=12.2.5.30=0 | |||
- libnsl=2.0.1=hd590300_0 | |||
- libnvfatbin=12.6.77=0 | |||
- libnvjitlink=12.4.127=0 | |||
- libnvjpeg=12.3.1.117=0 | |||
- libpciaccess=0.18=hd590300_0 | |||
- libpng=1.6.45=h943b412_0 | |||
- libsqlite=3.47.2=hee588c1_0 | |||
- libstdcxx=14.2.0=hc0a3c3a_1 | |||
- libstdcxx-ng=14.2.0=h4852527_1 | |||
- libtasn1=4.19.0=h166bdaf_0 | |||
- libtiff=4.7.0=hd9ff511_3 | |||
- libunistring=0.9.10=h7f98852_0 | |||
- libuuid=2.38.1=h0b41bf4_0 | |||
- libva=2.22.0=h8a09558_1 | |||
- libvpx=1.13.1=h59595ed_0 | |||
- libwebp=1.5.0=hae8dbeb_0 | |||
- libwebp-base=1.5.0=h851e524_0 | |||
- libxcb=1.17.0=h8a09558_0 | |||
- libxcrypt=4.4.36=hd590300_1 | |||
- libxml2=2.13.5=h0d44e9d_1 | |||
- libzlib=1.3.1=hb9d3cd8_2 | |||
- llvm-openmp=15.0.7=h0cdce71_0 | |||
- markupsafe=3.0.2=py311h2dc5d0c_1 | |||
- mkl=2022.1.0=hc2b9512_224 | |||
- mpc=1.3.1=h24ddda3_1 | |||
- mpfr=4.2.1=h90cbb55_3 | |||
- mpmath=1.3.0=pyhd8ed1ab_1 | |||
- ncurses=6.5=h2d0b736_2 | |||
- nettle=3.9.1=h7ab15ed_0 | |||
- networkx=3.4.2=pyh267e887_2 | |||
- numpy=2.2.1=py311hf916aec_0 | |||
- openh264=2.3.1=hcb278e6_2 | |||
- openjpeg=2.5.3=h5fbd93e_0 | |||
- openssl=3.4.0=h7b32b05_1 | |||
- p11-kit=0.24.1=hc5aa10d_0 | |||
- pillow=11.1.0=py311h1322bbf_0 | |||
- pip=24.3.1=pyh8b19718_2 | |||
- pthread-stubs=0.4=hb9d3cd8_1002 | |||
- pycparser=2.22=pyh29332c3_1 | |||
- pysocks=1.7.1=pyha55dd90_7 | |||
- python=3.11.11=h9e4cc4f_1_cpython | |||
- python_abi=3.11=5_cp311 | |||
- pytorch=2.5.1=py3.11_cuda12.4_cudnn9.1.0_0 | |||
- pytorch-cuda=12.4=hc786d27_7 | |||
- pytorch-mutex=1.0=cuda | |||
- pyyaml=6.0.2=py311h9ecbd09_1 | |||
- readline=8.2=h8228510_1 | |||
- requests=2.32.3=pyhd8ed1ab_1 | |||
- setuptools=75.8.0=pyhff2d567_0 | |||
- svt-av1=1.4.1=hcb278e6_0 | |||
- sympy=1.13.3=pyh2585a3b_105 | |||
- tk=8.6.13=noxft_h4845f30_101 | |||
- torchaudio=2.5.1=py311_cu124 | |||
- torchtriton=3.1.0=py311 | |||
- torchvision=0.20.1=py311_cu124 | |||
- typing_extensions=4.12.2=pyha770c72_1 | |||
- tzdata=2024b=hc8b5060_0 | |||
- urllib3=2.3.0=pyhd8ed1ab_0 | |||
- wayland=1.23.1=h3e06ad9_0 | |||
- wayland-protocols=1.37=hd8ed1ab_0 | |||
- wheel=0.45.1=pyhd8ed1ab_1 | |||
- x264=1!164.3095=h166bdaf_2 | |||
- x265=3.5=h924138e_3 | |||
- xorg-libx11=1.8.10=h4f16b4b_1 | |||
- xorg-libxau=1.0.12=hb9d3cd8_0 | |||
- xorg-libxdmcp=1.1.5=hb9d3cd8_0 | |||
- xorg-libxext=1.3.6=hb9d3cd8_0 | |||
- xorg-libxfixes=6.0.1=hb9d3cd8_0 | |||
- yaml=0.2.5=h7f98852_2 | |||
- zstandard=0.23.0=py311hbc35293_1 | |||
- zstd=1.5.6=ha6fb4c9_0 | |||
- pip: | |||
- torchmetrics>=1.4 # nicer MRR/Hits@K utilities | |||
- lovely-tensors>=0.1.0 # tensor utilities | |||
- lovely-numpy>=0.1.0 # numpy utilities | |||
- einops==0.8.1 | |||
- pykeen==1.11.1 | |||
- hydra-core==1.3.2 | |||
- tensorboard==2.19.0 |
@@ -0,0 +1,6 @@ | |||
from .pretty_logger import get_pretty_logger | |||
from .sampling import DeterministicRandomSampler | |||
from .utils import set_seed | |||
from .tb_handler import TensorBoardHandler | |||
from .params import CommonParams, TrainingParams | |||
from .checkpoint_manager import CheckpointManager |
@@ -0,0 +1,221 @@ | |||
from __future__ import annotations | |||
import ast
import re
from pathlib import Path
class CheckpointManager: | |||
""" | |||
A checkpoint manager that handles checkpoint directory creation and path resolution. | |||
Features: | |||
- Creates sequentially numbered checkpoint directories | |||
- Supports two loading modes: by components or by full path | |||
""" | |||
def __init__( | |||
self, | |||
root_directory: str | Path, | |||
run_name: str, | |||
load_only: bool = False, | |||
) -> None: | |||
""" | |||
Initialize the checkpoint manager. | |||
        Args:
            root_directory: Root directory for checkpoints
            run_name: Name of the run (used as suffix for checkpoint directories)
            load_only: If True, only resolve existing paths and do not create a
                new checkpoint directory
        """
self.root_directory = Path(root_directory) | |||
self.run_name = run_name | |||
self.checkpoint_directory = None | |||
if not load_only: | |||
self.root_directory.mkdir(parents=True, exist_ok=True) | |||
self.checkpoint_directory = self._create_checkpoint_directory() | |||
else: | |||
self.checkpoint_directory = "" | |||
def _find_existing_directories(self) -> list[int]: | |||
""" | |||
Find all existing directories with pattern xxx_<run_name> where xxx is a 3-digit number. | |||
Returns: | |||
List of existing sequence numbers | |||
""" | |||
pattern = re.compile(rf"^(\d{{3}})_{re.escape(self.run_name)}$") | |||
existing_numbers = [] | |||
if self.root_directory.exists(): | |||
for item in self.root_directory.iterdir(): | |||
if item.is_dir(): | |||
match = pattern.match(item.name) | |||
if match: | |||
existing_numbers.append(int(match.group(1))) | |||
return sorted(existing_numbers) | |||
def _create_checkpoint_directory(self) -> Path: | |||
""" | |||
Create a new checkpoint directory with the next sequential number. | |||
Returns: | |||
Path to the created checkpoint directory | |||
""" | |||
existing_numbers = self._find_existing_directories() | |||
# Determine the next number | |||
next_number = max(existing_numbers) + 1 if existing_numbers else 1 | |||
# Create directory name with 3-digit zero-padded number | |||
dir_name = f"{next_number:03d}_{self.run_name}" | |||
checkpoint_dir = self.root_directory / dir_name | |||
self.run_name = dir_name | |||
# Create the directory | |||
checkpoint_dir.mkdir(parents=True, exist_ok=True) | |||
return checkpoint_dir | |||
def get_checkpoint_directory(self) -> Path: | |||
""" | |||
Get the current checkpoint directory. | |||
Returns: | |||
Path to the checkpoint directory | |||
""" | |||
return self.checkpoint_directory | |||
def get_model_fpath(self, model_path: str) -> Path: | |||
""" | |||
Get the full path to a model checkpoint file. | |||
Args: | |||
model_path: Either a tuple (model_id, iteration) or a string path | |||
Returns: | |||
Full path to the model checkpoint file | |||
""" | |||
        try:
            # Safely parse strings such as "(3, 1000)" without executing arbitrary code.
            parsed = ast.literal_eval(model_path)
            if isinstance(parsed, tuple):
                model_id, iteration = parsed
                checkpoint_directory = f"{self.root_directory}/{model_id:03d}_{self.run_name}"
                filename = f"model-{iteration}.pt"
                return Path(f"{checkpoint_directory}/{filename}")
        except (ValueError, SyntaxError):
            pass
if isinstance(model_path, str): | |||
return Path(model_path) | |||
msg = "model_path must be a tuple (model_id, iteration) or a string path" | |||
raise ValueError(msg) | |||
def get_model_path_by_args( | |||
self, | |||
model_id: str | None = None, | |||
iteration: int | str | None = None, | |||
full_path: str | Path | None = None, | |||
) -> Path: | |||
""" | |||
Get the path for loading a model checkpoint. | |||
Two modes of operation: | |||
1. Component mode: Provide model_id and iteration to construct path | |||
2. Full path mode: Provide full_path directly | |||
Args: | |||
model_id: Model identifier (used in component mode) | |||
iteration: Training iteration (used in component mode) | |||
full_path: Full path to checkpoint (used in full path mode) | |||
Returns: | |||
Path to the checkpoint file | |||
Raises: | |||
ValueError: If neither component parameters nor full_path are provided, | |||
or if both modes are attempted simultaneously | |||
""" | |||
# Check which mode we're operating in | |||
component_mode = model_id is not None or iteration is not None | |||
full_path_mode = full_path is not None | |||
if component_mode and full_path_mode: | |||
msg = ( | |||
"Cannot use both component mode (model_id/iteration) " | |||
"and full path mode simultaneously" | |||
) | |||
raise ValueError(msg) | |||
if not component_mode and not full_path_mode: | |||
msg = "Must provide either (model_id and iteration) or full_path" | |||
raise ValueError(msg) | |||
if full_path_mode: | |||
# Full path mode: return the path as-is | |||
return Path(full_path) | |||
# Component mode: construct path from checkpoint directory, model_id, and iteration | |||
if model_id is None or iteration is None: | |||
msg = "Both model_id and iteration must be provided in component mode" | |||
raise ValueError(msg) | |||
filename = f"{model_id}_iter_{iteration}.pt" | |||
return self.checkpoint_directory / filename | |||
def save_checkpoint_info(self, info: dict) -> None: | |||
""" | |||
Save checkpoint information to a JSON file in the checkpoint directory. | |||
Args: | |||
info: Dictionary containing checkpoint metadata | |||
""" | |||
import json | |||
info_file = self.checkpoint_directory / "checkpoint_info.json" | |||
with open(info_file, "w") as f: | |||
json.dump(info, f, indent=2) | |||
def load_checkpoint_info(self) -> dict: | |||
""" | |||
Load checkpoint information from the checkpoint directory. | |||
Returns: | |||
Dictionary containing checkpoint metadata | |||
""" | |||
import json | |||
info_file = self.checkpoint_directory / "checkpoint_info.json" | |||
if info_file.exists(): | |||
with open(info_file) as f: | |||
return json.load(f) | |||
return {} | |||
def __str__(self) -> str: | |||
return ( | |||
f"CheckpointManager(root='{self.root_directory}', " | |||
f"run='{self.run_name}', ckpt_dir='{self.checkpoint_directory}')" | |||
) | |||
def __repr__(self) -> str: | |||
return self.__str__() | |||
# Example usage: | |||
if __name__ == "__main__": | |||
import tempfile | |||
# Create checkpoint manager with proper temp directory | |||
with tempfile.TemporaryDirectory() as temp_dir: | |||
ckpt_manager = CheckpointManager(temp_dir, "my_experiment") | |||
print(f"Checkpoint directory: {ckpt_manager.get_checkpoint_directory()}") | |||
# Component mode - construct path from components | |||
model_path = ckpt_manager.get_model_path_by_args(model_id="transe", iteration=1000) | |||
print(f"Model path (component mode): {model_path}") | |||
# Full path mode - use existing full path | |||
full_path = "/some/other/path/model.pt" | |||
model_path = ckpt_manager.get_model_path_by_args(full_path=full_path) | |||
print(f"Model path (full path mode): {model_path}") |
@@ -0,0 +1,45 @@ | |||
from __future__ import annotations | |||
from dataclasses import dataclass | |||
@dataclass | |||
class CommonParams: | |||
""" | |||
Common parameters for the application. | |||
""" | |||
model_name: str | |||
project_name: str | |||
run_name: str | |||
save_dpath: str | |||
save_every: int = 1000 | |||
load_path: str | None = "" | |||
log_dir: str = "./runs" | |||
log_every: int = 10 | |||
log_console_every: int = 100 | |||
evaluate_only: bool = False | |||
@dataclass | |||
class TrainingParams: | |||
""" | |||
Parameters for training. | |||
""" | |||
lr: float = 0.001 | |||
min_lr: float = 0.0001 | |||
weight_decay: float = 0.0 | |||
t0: int = 50 | |||
lr_step: int = 10 | |||
gamma: float = 1.0 | |||
batch_size: int = 128 | |||
num_workers: int = 0 | |||
seed: int = 42 | |||
num_train_steps: int = 100 | |||
num_epochs: int = 100 | |||
eval_every: int = 5 | |||
validation_sample_size: int = 1000 | |||
validation_batch_size: int = 128 |
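if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module): build the
    # parameter objects directly in Python. The concrete values are placeholders, not
    # project defaults; a real run would normally populate them from the config files.
    common = CommonParams(
        model_name="transe",
        project_name="kg_evaluation",
        run_name="debug_run",
        save_dpath="./checkpoints",
    )
    training = TrainingParams(lr=1e-3, batch_size=256, num_train_steps=10)
    print(common)
    print(training)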
@@ -0,0 +1,417 @@ | |||
# pretty_logger.py | |||
import json | |||
import logging | |||
import logging.handlers | |||
import sys | |||
import threading | |||
import time | |||
from contextlib import ContextDecorator | |||
from functools import wraps | |||
# ANSI escape codes for colors | |||
RESET = "\x1b[0m" | |||
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = ( | |||
"\x1b[30m", | |||
"\x1b[31m", | |||
"\x1b[32m", | |||
"\x1b[33m", | |||
"\x1b[34m", | |||
"\x1b[35m", | |||
"\x1b[36m", | |||
"\x1b[37m", | |||
) | |||
_LEVEL_COLORS = { | |||
logging.DEBUG: CYAN, | |||
logging.INFO: GREEN, | |||
logging.WARNING: YELLOW, | |||
logging.ERROR: RED, | |||
logging.CRITICAL: MAGENTA, | |||
} | |||
def _supports_color() -> bool: | |||
""" | |||
Detect whether the running terminal supports ANSI colors. | |||
""" | |||
    if hasattr(sys.stdout, "isatty") and sys.stdout.isatty():
        # Modern Windows terminals (Windows 10+) also understand ANSI sequences,
        # so a TTY is treated as color-capable on every platform.
        return True
    return False
class PrettyFormatter(logging.Formatter): | |||
""" | |||
A logging.Formatter that outputs colorized logs to the console | |||
(if enabled) and supports JSON formatting (if requested). | |||
""" | |||
def __init__( | |||
self, | |||
fmt: str = None, | |||
datefmt: str = "%Y-%m-%d %H:%M:%S", | |||
use_colors: bool = True, | |||
json_output: bool = False, | |||
): | |||
super().__init__(fmt=fmt, datefmt=datefmt) | |||
self.use_colors = use_colors and _supports_color() | |||
self.json_output = json_output | |||
def format(self, record: logging.LogRecord) -> str: | |||
# If JSON output is requested, dump a dict of relevant fields | |||
if self.json_output: | |||
log_record = { | |||
"timestamp": self.formatTime(record, self.datefmt), | |||
"name": record.name, | |||
"level": record.levelname, | |||
"message": record.getMessage(), | |||
} | |||
# Include extra/contextual fields | |||
for k, v in record.__dict__.get("extra_fields", {}).items(): | |||
log_record[k] = v | |||
if record.exc_info: | |||
log_record["exception"] = self.formatException(record.exc_info) | |||
return json.dumps(log_record, ensure_ascii=False) | |||
# Otherwise, build a colorized/“pretty” line | |||
level_color = _LEVEL_COLORS.get(record.levelno, WHITE) if self.use_colors else "" | |||
reset = RESET if self.use_colors else "" | |||
timestamp = self.formatTime(record, self.datefmt) | |||
name = record.name | |||
level = record.levelname | |||
# Basic format: [timestamp] [LEVEL] [name]: message | |||
base = f"[{timestamp}] [{level}] [{name}]: {record.getMessage()}" | |||
# If there are any extra fields attached via LoggerAdapter, append them | |||
extra_fields = record.__dict__.get("extra_fields", {}) | |||
if extra_fields: | |||
# key1=val1 key2=val2 … | |||
extras_str = " ".join(f"{k}={v}" for k, v in extra_fields.items()) | |||
base = f"{base} :: {extras_str}" | |||
# If exception info is present, include traceback | |||
if record.exc_info: | |||
exc_text = self.formatException(record.exc_info) | |||
base = f"{base}\n{exc_text}" | |||
if self.use_colors: | |||
return f"{level_color}{base}{reset}" | |||
else: | |||
return base | |||
class JsonFormatter(PrettyFormatter): | |||
""" | |||
Shortcut to force JSON formatting (ignores color flags). | |||
""" | |||
def __init__(self, datefmt: str = "%Y-%m-%d %H:%M:%S"): | |||
super().__init__(fmt=None, datefmt=datefmt, use_colors=False, json_output=True) | |||
class ContextFilter(logging.Filter): | |||
""" | |||
Inject extra_fields from a LoggerAdapter into the LogRecord, so the Formatter can find them. | |||
""" | |||
def filter(self, record: logging.LogRecord) -> bool: | |||
extra = getattr(record, "extra", None) | |||
if isinstance(extra, dict): | |||
record.extra_fields = extra | |||
else: | |||
record.extra_fields = {} | |||
return True | |||
class Timer(ContextDecorator): | |||
""" | |||
Context manager & decorator to measure execution time of a code block or function. | |||
Usage as context manager: | |||
with Timer(logger, "block name"): | |||
# code … | |||
Usage as decorator: | |||
@Timer(logger, "function foo") | |||
def foo(...): | |||
... | |||
""" | |||
def __init__(self, logger: logging.Logger, name: str = None, level: int = logging.INFO): | |||
self.logger = logger | |||
self.name = name or "" | |||
self.level = level | |||
def __enter__(self): | |||
self.start_time = time.time() | |||
return self | |||
def __exit__(self, exc_type, exc, exc_tb): | |||
elapsed = time.time() - self.start_time | |||
label = f"[Timer:{self.name}]" if self.name else "[Timer]" | |||
self.logger.log(self.level, f"{label} elapsed {elapsed:.4f} sec") | |||
return False # do not suppress exceptions | |||
def log_function(logger: logging.Logger, level: int = logging.DEBUG): | |||
""" | |||
Decorator factory to log function entry/exit and execution time. | |||
Usage: | |||
@log_function(logger, level=logging.INFO) | |||
def my_func(...): | |||
... | |||
""" | |||
def decorator(func): | |||
@wraps(func) | |||
def wrapper(*args, **kwargs): | |||
logger.log(level, f"➤ Entering {func.__name__}()") | |||
start = time.time() | |||
try: | |||
result = func(*args, **kwargs) | |||
except Exception: | |||
logger.exception(f"✖ Exception in {func.__name__}()") | |||
raise | |||
elapsed = time.time() - start | |||
logger.log(level, f"✔ Exiting {func.__name__}(), elapsed {elapsed:.4f} sec") | |||
return result | |||
return wrapper | |||
return decorator | |||
class PrettyLogger: | |||
""" | |||
Encapsulates a logging.Logger with “pretty” formatting, colorized console output, | |||
rotating file handler, JSON option, contextual fields, etc. | |||
""" | |||
_instances = {} | |||
_lock = threading.Lock() | |||
def __new__(cls, name: str = "root", **kwargs): | |||
""" | |||
Implements a simple singleton: multiple calls with the same name return the same instance. | |||
""" | |||
with cls._lock: | |||
if name not in cls._instances: | |||
instance = super().__new__(cls) | |||
cls._instances[name] = instance | |||
return cls._instances[name] | |||
def __init__( | |||
self, | |||
name: str = "root", | |||
level: int = logging.DEBUG, | |||
log_to_console: bool = True, | |||
console_level: int = logging.DEBUG, | |||
log_to_file: bool = False, | |||
log_file: str = "app.log", | |||
file_level: int = logging.INFO, | |||
max_bytes: int = 10 * 1024 * 1024, # 10 MB | |||
backup_count: int = 5, | |||
formatter: PrettyFormatter = None, | |||
json_output: bool = False, | |||
): | |||
""" | |||
Initialize (or re‐configure) a PrettyLogger. | |||
Args: | |||
name: logger name. | |||
level: root level for the logger (lowest level that will be processed). | |||
log_to_console: whether to attach a console (StreamHandler). | |||
console_level: level for console output. | |||
log_to_file: whether to attach a rotating file handler. | |||
log_file: path to the log file. | |||
file_level: level for file output. | |||
max_bytes: rotation threshold in bytes. | |||
backup_count: number of backup files to keep. | |||
formatter: an instance of PrettyFormatter. If None, a default is created. | |||
json_output: if True, forces JSON formatting on all handlers. | |||
""" | |||
# Avoid re‐initializing if already done | |||
logger = logging.getLogger(name) | |||
if getattr(logger, "_pretty_logger_inited", False): | |||
return | |||
self.logger = logger | |||
self.logger.setLevel(level) | |||
self.logger.propagate = False # avoid double‐logging if root logger is configured elsewhere | |||
        # Default line format; also reused for the (non-colorized) file handler below,
        # so it must be defined even when a custom formatter is supplied.
        fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s"
        # Create a default formatter if not supplied
        if formatter is None:
            formatter = PrettyFormatter(
                fmt=fmt, use_colors=not json_output, json_output=json_output
            )
# Add ContextFilter so extra_fields is always present | |||
context_filter = ContextFilter() | |||
self.logger.addFilter(context_filter) | |||
# Console handler | |||
if log_to_console: | |||
ch = logging.StreamHandler(sys.stdout) | |||
ch.setLevel(console_level) | |||
ch.setFormatter(formatter) | |||
self.logger.addHandler(ch) | |||
# Rotating file handler | |||
if log_to_file: | |||
fh = logging.handlers.RotatingFileHandler( | |||
filename=log_file, | |||
maxBytes=max_bytes, | |||
backupCount=backup_count, | |||
encoding="utf-8", | |||
) | |||
fh.setLevel(file_level) | |||
# For file, typically we don’t colorize | |||
file_formatter = ( | |||
JsonFormatter(datefmt="%Y-%m-%d %H:%M:%S") | |||
if json_output | |||
else PrettyFormatter(fmt=fmt, use_colors=False, json_output=False) | |||
) | |||
fh.setFormatter(file_formatter) | |||
self.logger.addHandler(fh) | |||
# Mark as initialized | |||
setattr(self.logger, "_pretty_logger_inited", True) | |||
def addHandler(self, handler: logging.Handler): | |||
""" | |||
Add a custom handler to the logger. | |||
This can be used to add any logging.Handler instance. | |||
""" | |||
self.logger.addHandler(handler) | |||
def bind(self, **kwargs): | |||
""" | |||
Return a LoggerAdapter that automatically injects extra fields into all log records. | |||
Example: | |||
ctx_logger = pretty_logger.bind(request_id=123, user="alice") | |||
ctx_logger.info("Something happened") # will include request_id and user in the output | |||
""" | |||
return logging.LoggerAdapter(self.logger, {"extra": kwargs}) | |||
def add_stream_handler( | |||
self, stream=None, level: int = logging.DEBUG, formatter: PrettyFormatter = None | |||
): | |||
""" | |||
Add an additional stream handler (e.g. to stderr). | |||
""" | |||
handler = logging.StreamHandler(stream or sys.stdout) | |||
handler.setLevel(level) | |||
if formatter is None: | |||
fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s" | |||
formatter = PrettyFormatter(fmt=fmt) | |||
handler.setFormatter(formatter) | |||
self.logger.addHandler(handler) | |||
return handler | |||
def add_rotating_file_handler( | |||
self, | |||
log_file: str, | |||
level: int = logging.INFO, | |||
max_bytes: int = 10 * 1024 * 1024, | |||
backup_count: int = 5, | |||
json_output: bool = False, | |||
): | |||
""" | |||
Add another rotating file handler at runtime. | |||
""" | |||
fh = logging.handlers.RotatingFileHandler( | |||
filename=log_file, | |||
maxBytes=max_bytes, | |||
backupCount=backup_count, | |||
encoding="utf-8", | |||
) | |||
fh.setLevel(level) | |||
if json_output: | |||
formatter = JsonFormatter() | |||
else: | |||
fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s" | |||
formatter = PrettyFormatter(fmt=fmt, use_colors=False, json_output=False) | |||
fh.setFormatter(formatter) | |||
self.logger.addHandler(fh) | |||
return fh | |||
def remove_handler(self, handler: logging.Handler): | |||
""" | |||
Remove a handler from the logger. | |||
""" | |||
self.logger.removeHandler(handler) | |||
def set_level(self, level: int): | |||
""" | |||
Change the logger’s root level at runtime. | |||
""" | |||
self.logger.setLevel(level) | |||
def get_logger(self) -> logging.Logger: | |||
return self.logger | |||
# Convenience methods to expose common logging calls directly from this object: | |||
def debug(self, msg, *args, **kwargs): | |||
self.logger.debug(msg, *args, **kwargs) | |||
def info(self, msg, *args, **kwargs): | |||
self.logger.info(msg, *args, **kwargs) | |||
def warning(self, msg, *args, **kwargs): | |||
self.logger.warning(msg, *args, **kwargs) | |||
def error(self, msg, *args, **kwargs): | |||
self.logger.error(msg, *args, **kwargs) | |||
def critical(self, msg, *args, **kwargs): | |||
self.logger.critical(msg, *args, **kwargs) | |||
def exception(self, msg, *args, exc_info=True, **kwargs): | |||
""" | |||
Log an error message with exception traceback. By default exc_info=True. | |||
""" | |||
self.logger.error(msg, *args, exc_info=exc_info, **kwargs) | |||
# Convenience function for a module‐level “global” pretty logger: | |||
_global_loggers = {} | |||
_global_lock = threading.Lock() | |||
def get_pretty_logger( | |||
name: str = "root", | |||
level: int = logging.DEBUG, | |||
log_to_console: bool = True, | |||
console_level: int = logging.DEBUG, | |||
log_to_file: bool = False, | |||
log_file: str = "app.log", | |||
file_level: int = logging.INFO, | |||
max_bytes: int = 10 * 1024 * 1024, | |||
backup_count: int = 5, | |||
json_output: bool = False, | |||
) -> PrettyLogger: | |||
""" | |||
Return a PrettyLogger instance for `name`, creating/configuring it on first use. | |||
Subsequent calls with the same `name` will return the same logger (singleton‐style). | |||
""" | |||
with _global_lock: | |||
if name not in _global_loggers: | |||
pl = PrettyLogger( | |||
name=name, | |||
level=level, | |||
log_to_console=log_to_console, | |||
console_level=console_level, | |||
log_to_file=log_to_file, | |||
log_file=log_file, | |||
file_level=file_level, | |||
max_bytes=max_bytes, | |||
backup_count=backup_count, | |||
json_output=json_output, | |||
) | |||
_global_loggers[name] = pl | |||
return _global_loggers[name] |
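if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original module): console logging,
    # contextual fields via bind(), and the Timer / log_function helpers defined above.
    demo = get_pretty_logger("demo", log_to_console=True)
    demo.info("plain message")
    # Contextual fields are rendered as "key=value" pairs by PrettyFormatter.
    ctx = demo.bind(request_id=42, user="alice")
    ctx.info("message with context")
    # Timer as a context manager logs the elapsed time of the block.
    with Timer(demo.get_logger(), "sleep block"):
        time.sleep(0.05)
    # log_function as a decorator logs entry, exit, and elapsed time.
    @log_function(demo.get_logger(), level=logging.INFO)
    def add(a, b):
        return a + b
    add(1, 2)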
@@ -0,0 +1,66 @@ | |||
from collections.abc import Iterator
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
def negative_sampling(
    pos: torch.Tensor,
    num_entities: int,
    num_negatives: int = 1,
    mode: tuple[str, ...] = ("head", "tail"),
) -> torch.Tensor:
    """Corrupt triples by replacing the head and/or tail with uniformly sampled entities.
    Each positive triple is repeated ``num_negatives`` times, so the result has
    shape ``(len(pos) * num_negatives, 3)``.
    """
    # repeat_interleave returns a new tensor, so the caller's triples stay untouched.
    pos = pos.repeat_interleave(num_negatives, dim=0)
    if "head" in mode:
        neg = torch.randint(0, num_entities, (pos.shape[0],), device=pos.device)
        pos[:, 0] = neg
    if "tail" in mode:
        neg = torch.randint(0, num_entities, (pos.shape[0],), device=pos.device)
        pos[:, 2] = neg
    return pos
class DeterministicRandomSampler(Sampler): | |||
""" | |||
A deterministic random sampler that selects a fixed number of samples | |||
in a reproducible manner using a given seed. | |||
""" | |||
def __init__(self, dataset_size: int, sample_size: int = 0, seed: int = 42) -> None: | |||
""" | |||
Args: | |||
            dataset_size (int): Total number of items in the dataset.
            sample_size (int, optional): Number of samples to draw. If 0, the full dataset is used.
seed (int): Seed for reproducibility. | |||
""" | |||
self.dataset_size = dataset_size | |||
self.seed = seed | |||
self.sample_size = sample_size if sample_size != 0 else self.dataset_size | |||
# Ensure sample size is within dataset size | |||
if self.sample_size > self.dataset_size: | |||
msg = f"Sample size {self.sample_size} exceeds dataset size {self.dataset_size}." | |||
raise ValueError( | |||
msg, | |||
) | |||
self.indices = self._generate_deterministic_indices() | |||
def _generate_deterministic_indices(self) -> list[int]: | |||
""" | |||
Generates a fixed random subset of indices using the given seed. | |||
""" | |||
rng = np.random.default_rng(self.seed) # NumPy's Generator for better reproducibility | |||
all_indices = rng.permutation(self.dataset_size) # Shuffle full dataset indices | |||
return all_indices[: self.sample_size].tolist() # Select only the desired number of samples | |||
    def __iter__(self) -> Iterator[int]:
""" | |||
Yields the shuffled dataset indices. | |||
""" | |||
return iter(self.indices) | |||
def __len__(self) -> int: | |||
""" | |||
Returns the total number of samples drawn. | |||
""" | |||
return self.sample_size |
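if __name__ == "__main__":
    # Usage sketch (illustrative): corrupt the tails of a toy batch of (head, relation,
    # tail) triples, and draw a reproducible validation subset with the sampler above.
    triples = torch.tensor([[0, 0, 1], [2, 1, 3]])
    negatives = negative_sampling(triples, num_entities=10, num_negatives=2, mode=["tail"])
    print(negatives.shape)  # torch.Size([4, 3]) -- two negatives per positive
    sampler = DeterministicRandomSampler(dataset_size=100, sample_size=5, seed=7)
    print(list(sampler))  # the same five indices on every run with seed=7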
@@ -0,0 +1,21 @@ | |||
from __future__ import annotations | |||
import logging | |||
from typing import TYPE_CHECKING | |||
if TYPE_CHECKING: | |||
from torch.utils.tensorboard import SummaryWriter | |||
class TensorBoardHandler(logging.Handler): | |||
"""Convert each log record into a TensorBoard text summary.""" | |||
def __init__(self, writer: SummaryWriter, tag: str = "logs"): | |||
super().__init__(level=logging.INFO) | |||
self.writer = writer | |||
self.tag = tag | |||
self._step = 0 # will be incremented every emit | |||
def emit(self, record: logging.LogRecord) -> None: | |||
self.writer.add_text(self.tag, self.format(record), global_step=self._step) | |||
self._step += 1 |
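if __name__ == "__main__":
    # Usage sketch (illustrative): route INFO-level log lines into TensorBoard as text
    # summaries; the log_dir and tag below are placeholders.
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(log_dir="runs/log_demo")
    handler = TensorBoardHandler(writer, tag="training_logs")
    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
    demo_logger = logging.getLogger("tb_demo")
    demo_logger.setLevel(logging.INFO)
    demo_logger.addHandler(handler)
    demo_logger.info("this line shows up under the 'training_logs' text tab")
    writer.close()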
@@ -0,0 +1,12 @@ | |||
from torch.optim.lr_scheduler import StepLR | |||
class StepMinLR(StepLR): | |||
def __init__(self, optimizer, step_size, gamma=0.1, min_lr=0.0, last_epoch=-1): | |||
self.min_lr = min_lr | |||
super().__init__(optimizer, step_size, gamma, last_epoch) | |||
def get_lr(self): | |||
# use StepLR's formula, then clamp | |||
raw_lrs = super().get_lr() | |||
return [max(lr, self.min_lr) for lr in raw_lrs] |
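if __name__ == "__main__":
    # Usage sketch (illustrative): decay the learning rate by 0.5 every 10 steps, but
    # never let it drop below 1e-4.
    import torch
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=1e-2)
    sched = StepMinLR(opt, step_size=10, gamma=0.5, min_lr=1e-4)
    for _ in range(100):
        opt.step()
        sched.step()
    print(opt.param_groups[0]["lr"])  # 1e-4, where plain StepLR would give ~9.8e-6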
@@ -0,0 +1,25 @@ | |||
import random | |||
from enum import Enum | |||
import numpy as np | |||
import torch | |||
def set_seed(seed: int = 42) -> None: | |||
""" | |||
Set the random seed for reproducibility. | |||
Args: | |||
seed (int): The seed value to set for random number generation. | |||
""" | |||
torch.manual_seed(seed) | |||
random.seed(seed) | |||
    np.random.seed(seed)  # seed NumPy's global RNG (default_rng alone would not affect global state)
if torch.cuda.is_available(): | |||
torch.cuda.manual_seed_all(seed) | |||
class Target(str, Enum): # ↔ PyKEEN's LABEL_* constants | |||
HEAD = "head" | |||
TAIL = "tail" | |||
RELATION = "relation" |
@@ -0,0 +1,85 @@ | |||
from __future__ import annotations | |||
from abc import ABC, abstractmethod | |||
from typing import TYPE_CHECKING | |||
import torch | |||
from torch import nn | |||
if TYPE_CHECKING: | |||
from torch.optim.lr_scheduler import LRScheduler | |||
class ModelTrainerBase(ABC): | |||
def __init__( | |||
self, | |||
model: nn.Module, | |||
optimizer: torch.optim.Optimizer, | |||
scheduler: LRScheduler | None = None, | |||
device: torch.device = "cpu", | |||
) -> None: | |||
self.model = model.to(device) | |||
self.optimizer = optimizer | |||
self.scheduler = scheduler | |||
self.device = device | |||
def save(self, save_dpath: str, step: int) -> None: | |||
""" | |||
Save the model and optimizer state to the specified directory. | |||
Args: | |||
save_dpath (str): The directory path where the model and optimizer state will be saved. | |||
""" | |||
data = { | |||
"model": self.model.state_dict(), | |||
"optimizer": self.optimizer.state_dict(), | |||
"step": step, | |||
} | |||
if self.scheduler is not None: | |||
data["scheduler"] = self.scheduler.state_dict() | |||
torch.save( | |||
data, | |||
f"{save_dpath}/model-{step}.pt", | |||
) | |||
    def load(self, load_fpath: str) -> int:
        """
        Load the model and optimizer state from the specified checkpoint file.
        Args:
            load_fpath (str): Path to a checkpoint file produced by `save`.
        Returns:
            int: The training step stored in the checkpoint.
        """
        data = torch.load(load_fpath, map_location="cpu")
        model_state_dict = data["model"]
        self.model.load_state_dict(model_state_dict)
        self.model.to(self.device)  # move everything to the correct device
optimizer_state_dict = data["optimizer"] | |||
self.optimizer.load_state_dict(optimizer_state_dict) | |||
if self.scheduler is not None: | |||
scheduler_data = data.get("scheduler", {}) | |||
self.scheduler.load_state_dict(scheduler_data) | |||
return data["step"] | |||
@abstractmethod | |||
def _loss(self, batch: torch.Tensor) -> torch.Tensor: | |||
""" | |||
Compute the loss for a batch of data. | |||
Args: | |||
batch (torch.Tensor): A tensor containing a batch of data. | |||
Returns: | |||
torch.Tensor: The computed loss for the batch. | |||
""" | |||
... | |||
@abstractmethod | |||
def train_step(self, batch: torch.Tensor) -> float: ... | |||
@abstractmethod | |||
def eval_step(self, batch: torch.Tensor) -> float: ... |
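# Minimal concrete subclass (an illustrative sketch, not one of the project trainers):
# it shows the contract expected from ModelTrainerBase -- _loss, train_step, eval_step.
# The toy model and squared-score loss below are placeholders.
class _ToyTrainer(ModelTrainerBase):
    def _loss(self, batch: torch.Tensor) -> torch.Tensor:
        # Score every triple with the wrapped module and push the scores towards zero.
        return self.model(batch.float()).pow(2).mean()
    def train_step(self, batch: torch.Tensor) -> float:
        self.model.train()
        self.optimizer.zero_grad()
        loss = self._loss(batch.to(self.device))
        loss.backward()
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()
        return loss.item()
    def eval_step(self, batch: torch.Tensor) -> float:
        self.model.eval()
        with torch.no_grad():
            return self._loss(batch.to(self.device)).item()
if __name__ == "__main__":
    toy_model = nn.Linear(3, 1)
    trainer = _ToyTrainer(toy_model, torch.optim.SGD(toy_model.parameters(), lr=0.1))
    batch = torch.randint(0, 5, (8, 3))
    print(trainer.train_step(batch))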
@@ -0,0 +1,148 @@ | |||
import torch | |||
import torch.nn.functional as F | |||
from einops import repeat | |||
from pykeen.sampling import BernoulliNegativeSampler | |||
from pykeen.training import SLCWATrainingLoop | |||
from pykeen.triples.instances import SLCWABatch | |||
from torch.optim import Adam | |||
from torch.optim.lr_scheduler import StepLR | |||
from torch.utils.data import DataLoader | |||
from metrics.ranking import simple_ranking | |||
from models.translation.trans_e import TransE | |||
from tools.params import TrainingParams | |||
from training.model_trainers.model_trainer_base import ModelTrainerBase | |||
class TransETrainer(ModelTrainerBase): | |||
def __init__( | |||
self, | |||
model: TransE, | |||
training_params: TrainingParams, | |||
triples_factory: torch.Tensor, | |||
mapped_triples: torch.Tensor, | |||
device: torch.device, | |||
margin: float = 1.0, | |||
n_negative: int = 1, | |||
loss_fn: str = "margin", | |||
**kwargs, | |||
) -> None: | |||
optimizer = Adam( | |||
model.inner_model.parameters(), | |||
lr=training_params.lr, | |||
weight_decay=training_params.weight_decay, | |||
) | |||
# scheduler = None # Placeholder for scheduler, can be implemented later | |||
scheduler = StepLR( | |||
optimizer, | |||
step_size=training_params.lr_step, | |||
gamma=training_params.gamma, | |||
) | |||
super().__init__(model, optimizer, scheduler, device) | |||
self.triples_factory = triples_factory | |||
self.margin = margin | |||
self.n_negative = n_negative | |||
        self.loss_fn = loss_fn
        self.adv_temp = adv_temp  # temperature for the self-adversarial weighting in _loss
self.negative_sampler = BernoulliNegativeSampler( | |||
mapped_triples=mapped_triples.to(device), | |||
num_entities=model.num_entities, | |||
num_relations=model.num_relations, | |||
num_negs_per_pos=n_negative, | |||
).to(device) | |||
self.training_loop = SLCWATrainingLoop( | |||
model=model.inner_model, | |||
triples_factory=triples_factory, | |||
optimizer=optimizer, | |||
negative_sampler=self.negative_sampler, | |||
lr_scheduler=scheduler, | |||
) | |||
def create_data_loader( | |||
self, | |||
triples_factory: torch.Tensor, | |||
batch_size: int, | |||
shuffle: bool = True, | |||
) -> DataLoader: | |||
triples_factory = self.triples_factory if triples_factory is None else triples_factory | |||
return self.training_loop._create_training_data_loader( | |||
triples_factory=triples_factory, | |||
sampler=None, | |||
batch_size=batch_size, | |||
shuffle=shuffle, | |||
drop_last=False, | |||
) | |||
def _loss(self, pos_scores: torch.Tensor, neg_scores: torch.Tensor) -> torch.Tensor: | |||
""" | |||
Compute the loss based on the positive and negative scores. | |||
Args: | |||
pos_scores (torch.Tensor): Scores for positive triples. | |||
neg_scores (torch.Tensor): Scores for negative triples. | |||
Returns: | |||
torch.Tensor: The computed loss. | |||
""" | |||
k = neg_scores.shape[1] | |||
if self.loss_fn == "margin": | |||
target = -torch.ones_like(neg_scores) # want s_pos < s_neg | |||
loss = F.margin_ranking_loss( | |||
repeat(pos_scores, "b -> b k", k=k), | |||
neg_scores, | |||
target, | |||
margin=self.margin, | |||
) | |||
return loss | |||
if self.loss_fn == "adversarial": | |||
# b = pos_scores.size(0) | |||
# neg_scores = rearrange(neg_scores, "(b k) -> b k", b=b, k=k) | |||
# importance weights (detach so gradients flow only through log-sigmoid terms) | |||
w = F.softmax(self.adv_temp * neg_scores, dim=1).detach() | |||
pos_term = -F.logsigmoid(self.margin - pos_scores) # (B,) | |||
neg_term = -(w * F.logsigmoid(neg_scores - self.margin)).sum(dim=1) # (B,) | |||
loss = (pos_term + neg_term).mean() | |||
return loss | |||
        msg = f"Unknown loss_fn {self.loss_fn!r}; expected 'margin' or 'adversarial'"
        raise ValueError(msg)
def train_step(self, batch: SLCWABatch, step: int) -> float: | |||
""" | |||
Perform a training step on the given batch of data. | |||
Args: | |||
batch (torch.Tensor): A tensor containing a batch of triples. | |||
Returns: | |||
float: The computed loss for the batch. | |||
""" | |||
continue_training = step > 0 | |||
loss = self.training_loop.train( | |||
triples_factory=self.triples_factory, | |||
num_epochs=step + 1, | |||
batch_size=self.training_loop._get_batch_size(batch), | |||
continue_training=continue_training, | |||
use_tqdm=False, | |||
use_tqdm_batch=False, | |||
label_smoothing=0.0, # Assuming no label smoothing for simplicity | |||
# num_workers=3, | |||
)[-1] | |||
return loss | |||
def eval_step(self, batch: torch.Tensor) -> torch.Tensor: | |||
"""Return rank tensor for this batch (needed for MRR / Hits@K).""" | |||
self.model.eval() | |||
with torch.no_grad(): | |||
return simple_ranking(self.model, batch.to(self.device)) |
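if __name__ == "__main__":
    # Standalone sketch of the self-adversarial branch of _loss above (illustrative):
    # the softmax over negative scores (temperature adv_temp) up-weights harder
    # negatives and is detached so gradients only flow through the log-sigmoid terms.
    pos_scores = torch.tensor([0.2, 0.5])  # (B,)
    neg_scores = torch.tensor([[1.0, 2.0], [0.3, 0.9]])  # (B, K)
    margin, adv_temp = 1.0, 1.0
    w = F.softmax(adv_temp * neg_scores, dim=1).detach()
    pos_term = -F.logsigmoid(margin - pos_scores)
    neg_term = -(w * F.logsigmoid(neg_scores - margin)).sum(dim=1)
    print((pos_term + neg_term).mean())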
@@ -0,0 +1,121 @@ | |||
import torch | |||
from pykeen.losses import MarginRankingLoss | |||
from pykeen.sampling import BernoulliNegativeSampler | |||
from pykeen.training import SLCWATrainingLoop | |||
from pykeen.triples.instances import SLCWABatch | |||
from torch.optim import Adam | |||
from metrics.ranking import simple_ranking | |||
from models.translation.trans_h import TransH | |||
from tools.params import TrainingParams | |||
from tools.train import StepMinLR | |||
from training.model_trainers.model_trainer_base import ModelTrainerBase | |||
class TransHTrainer(ModelTrainerBase): | |||
def __init__( | |||
self, | |||
model: TransH, | |||
training_params: TrainingParams, | |||
triples_factory: torch.Tensor, | |||
mapped_triples: torch.Tensor, | |||
device: torch.device, | |||
margin: float = 1.0, | |||
n_negative: int = 1, | |||
regul_rate: float = 0.0, | |||
loss_fn: str = "margin", | |||
**kwargs, | |||
) -> None: | |||
optimizer = Adam( | |||
model.inner_model.parameters(), | |||
lr=training_params.lr, | |||
weight_decay=training_params.weight_decay, | |||
) | |||
scheduler = StepMinLR( | |||
optimizer, | |||
step_size=training_params.lr_step, | |||
gamma=training_params.gamma, | |||
min_lr=training_params.min_lr, | |||
) | |||
super().__init__(model, optimizer, scheduler, device) | |||
self.triples_factory = triples_factory | |||
self.margin = margin | |||
self.n_negative = n_negative | |||
self.regul_rate = regul_rate | |||
self.loss_fn = loss_fn | |||
self.negative_sampler = BernoulliNegativeSampler( | |||
mapped_triples=mapped_triples.to(device), | |||
num_entities=model.num_entities, | |||
num_relations=model.num_relations, | |||
num_negs_per_pos=n_negative, | |||
).to(device) | |||
self.training_loop = SLCWATrainingLoop( | |||
model=model.inner_model, | |||
triples_factory=triples_factory, | |||
optimizer=optimizer, | |||
negative_sampler=self.negative_sampler, | |||
lr_scheduler=scheduler, | |||
) | |||
def _loss(self) -> torch.Tensor: | |||
return self.model.loss | |||
def create_data_loader( | |||
self, | |||
triples_factory: torch.Tensor, | |||
batch_size: int, | |||
shuffle: bool = True, | |||
) -> torch.utils.data.DataLoader: | |||
triples_factory = self.triples_factory if triples_factory is None else triples_factory | |||
return self.training_loop._create_training_data_loader( | |||
triples_factory=triples_factory, | |||
sampler=None, | |||
batch_size=batch_size, | |||
shuffle=shuffle, | |||
drop_last=False, | |||
) | |||
@property | |||
    def loss(self) -> MarginRankingLoss:
return MarginRankingLoss(margin=self.margin, reduction="mean") | |||
def train_step( | |||
self, | |||
batch: SLCWABatch, | |||
step: int, | |||
) -> float: | |||
""" | |||
Perform a training step on the given batch of data. | |||
Args: | |||
batch (torch.Tensor): A tensor containing a batch of triples. | |||
Returns: | |||
float: The computed loss for the batch. | |||
""" | |||
continue_training = step > 0 | |||
loss = self.training_loop.train( | |||
triples_factory=self.triples_factory, | |||
num_epochs=step + 1, | |||
batch_size=self.training_loop._get_batch_size(batch), | |||
continue_training=continue_training, | |||
use_tqdm=False, | |||
use_tqdm_batch=False, | |||
label_smoothing=0.0, # Assuming no label smoothing for simplicity | |||
)[-1] | |||
return loss | |||
def eval_step(self, batch: torch.Tensor) -> torch.Tensor: | |||
"""Return rank tensor for this batch (needed for MRR / Hits@K).""" | |||
self.model.eval() | |||
with torch.no_grad(): | |||
return simple_ranking(self.model, batch.to(self.device)) |
@@ -0,0 +1,116 @@ | |||
import torch | |||
from pykeen.sampling import BernoulliNegativeSampler | |||
from pykeen.training import SLCWATrainingLoop | |||
from pykeen.triples.instances import SLCWABatch | |||
from torch.optim import Adam | |||
from metrics.ranking import simple_ranking | |||
from models.translation.trans_r import TransR | |||
from tools.params import TrainingParams | |||
from tools.train import StepMinLR | |||
from training.model_trainers.model_trainer_base import ModelTrainerBase | |||
class TransRTrainer(ModelTrainerBase): | |||
def __init__( | |||
self, | |||
model: TransR, | |||
training_params: TrainingParams, | |||
triples_factory: torch.Tensor, | |||
mapped_triples: torch.Tensor, | |||
device: torch.device, | |||
margin: float = 1.0, | |||
n_negative: int = 1, | |||
regul_rate: float = 0.0, | |||
loss_fn: str = "margin", | |||
**kwargs, | |||
) -> None: | |||
optimizer = Adam( | |||
model.inner_model.parameters(), | |||
lr=training_params.lr, | |||
weight_decay=training_params.weight_decay, | |||
) | |||
scheduler = StepMinLR( | |||
optimizer, | |||
step_size=training_params.lr_step, | |||
gamma=training_params.gamma, | |||
min_lr=training_params.min_lr, | |||
) | |||
super().__init__(model, optimizer, scheduler, device) | |||
self.triples_factory = triples_factory | |||
self.margin = margin | |||
self.n_negative = n_negative | |||
self.regul_rate = regul_rate | |||
self.loss_fn = loss_fn | |||
self.negative_sampler = BernoulliNegativeSampler( | |||
mapped_triples=mapped_triples.to(device), | |||
num_entities=model.num_entities, | |||
num_relations=model.num_relations, | |||
num_negs_per_pos=n_negative, | |||
).to(device) | |||
self.training_loop = SLCWATrainingLoop( | |||
model=model.inner_model, | |||
triples_factory=triples_factory, | |||
optimizer=optimizer, | |||
negative_sampler=self.negative_sampler, | |||
lr_scheduler=scheduler, | |||
) | |||
def create_data_loader( | |||
self, | |||
triples_factory: torch.Tensor, | |||
batch_size: int, | |||
shuffle: bool = True, | |||
) -> torch.utils.data.DataLoader: | |||
triples_factory = self.triples_factory if triples_factory is None else triples_factory | |||
return self.training_loop._create_training_data_loader( | |||
triples_factory=triples_factory, | |||
sampler=None, | |||
batch_size=batch_size, | |||
shuffle=shuffle, | |||
drop_last=False, | |||
) | |||
def _loss(self) -> torch.Tensor: | |||
return self.model.loss | |||
def train_step( | |||
self, | |||
batch: SLCWABatch, | |||
step: int, | |||
) -> float: | |||
""" | |||
Perform a training step on the given batch of data. | |||
Args: | |||
batch (torch.Tensor): A tensor containing a batch of triples. | |||
Returns: | |||
float: The computed loss for the batch. | |||
""" | |||
continue_training = step > 0 | |||
loss = self.training_loop.train( | |||
triples_factory=self.triples_factory, | |||
num_epochs=step + 1, | |||
batch_size=self.training_loop._get_batch_size(batch), | |||
continue_training=continue_training, | |||
use_tqdm=False, | |||
use_tqdm_batch=False, | |||
label_smoothing=0.0, # Assuming no label smoothing for simplicity | |||
)[-1] | |||
return loss | |||
def eval_step(self, batch: torch.Tensor) -> torch.Tensor: | |||
"""Return rank tensor for this batch (needed for MRR / Hits@K).""" | |||
self.model.eval() | |||
with torch.no_grad(): | |||
return simple_ranking(self.model, batch.to(self.device)) |
@@ -0,0 +1,214 @@ | |||
from __future__ import annotations | |||
from pathlib import Path | |||
from typing import TYPE_CHECKING, Any, Iterator
import torch | |||
from pykeen.evaluation import RankBasedEvaluator | |||
from pykeen.metrics.ranking import HitsAtK, InverseHarmonicMeanRank | |||
from torch.utils.data import DataLoader | |||
from torch.utils.tensorboard import SummaryWriter | |||
from tools import ( | |||
CheckpointManager, | |||
DeterministicRandomSampler, | |||
TensorBoardHandler, | |||
get_pretty_logger, | |||
set_seed, | |||
) | |||
if TYPE_CHECKING: | |||
from data.kg_dataset import KGDataset | |||
from tools import CommonParams, TrainingParams | |||
from .model_trainers.model_trainer_base import ModelTrainerBase | |||
logger = get_pretty_logger(__name__) | |||
cpu_device = torch.device("cpu") | |||
class Trainer: | |||
def __init__( | |||
self, | |||
train_dataset: KGDataset, | |||
val_dataset: KGDataset | None, | |||
model_trainer: ModelTrainerBase, | |||
common_params: CommonParams, | |||
training_params: TrainingParams, | |||
device: torch.device = cpu_device, | |||
) -> None: | |||
set_seed(training_params.seed) | |||
self.train_dataset = train_dataset | |||
self.val_dataset = val_dataset if val_dataset is not None else train_dataset | |||
self.model_trainer = model_trainer | |||
self.common_params = common_params | |||
self.training_params = training_params | |||
self.device = device | |||
self.train_loader = self.model_trainer.create_data_loader( | |||
triples_factory=self.train_dataset.triples_factory, | |||
batch_size=training_params.batch_size, | |||
shuffle=True, | |||
) | |||
self.train_iterator = self._data_iterator(self.train_loader) | |||
seed = 1234 | |||
val_sampler = DeterministicRandomSampler( | |||
len(self.val_dataset), | |||
sample_size=self.training_params.validation_sample_size, | |||
seed=seed, | |||
) | |||
self.valid_loader = DataLoader( | |||
self.val_dataset, | |||
batch_size=training_params.validation_batch_size, | |||
shuffle=False, | |||
num_workers=training_params.num_workers, | |||
sampler=val_sampler, | |||
) | |||
self.valid_iterator = self._data_iterator(self.valid_loader) | |||
self.step = 0 | |||
self.log_every = common_params.log_every | |||
self.log_console_every = common_params.log_console_every | |||
self.save_dpath = common_params.save_dpath | |||
self.save_every = common_params.save_every | |||
self.checkpoint_manager = CheckpointManager( | |||
root_directory=self.save_dpath, | |||
run_name=common_params.run_name, | |||
load_only=common_params.evaluate_only, | |||
) | |||
self.ckpt_dpath = self.checkpoint_manager.checkpoint_directory | |||
if not common_params.evaluate_only: | |||
logger.info(f"Saving checkpoints to {self.ckpt_dpath}") | |||
if self.common_params.load_path: | |||
self.load() | |||
self.writer = None | |||
if not common_params.evaluate_only: | |||
run_name = self.checkpoint_manager.run_name | |||
self.writer = SummaryWriter(log_dir=f"{common_params.log_dir}/{run_name}") | |||
logger.addHandler(TensorBoardHandler(self.writer)) | |||
def log(self, tag: str, log_value: Any, console_msg: str | None = None) -> None: | |||
""" | |||
Log a message to the console and TensorBoard. | |||
Args: | |||
message (str): The message to log. | |||
""" | |||
if console_msg is not None: | |||
logger.info(f"{console_msg}") | |||
if self.writer is not None: | |||
self.writer.add_scalar(tag, log_value, self.step) | |||
def save(self) -> None: | |||
save_dpath = self.ckpt_dpath | |||
save_dpath = Path(save_dpath) | |||
save_dpath.mkdir(parents=True, exist_ok=True) | |||
self.model_trainer.save(save_dpath, self.step) | |||
def load(self) -> None: | |||
load_path = self.checkpoint_manager.get_model_fpath(self.common_params.load_path) | |||
load_fpath = Path(load_path) | |||
if load_fpath.exists(): | |||
self.step = self.model_trainer.load(load_fpath) | |||
logger.info(f"Model loaded from {load_fpath}.") | |||
else: | |||
msg = f"Load path {load_fpath} does not exist." | |||
raise FileNotFoundError(msg) | |||
    def _data_iterator(self, loader: DataLoader) -> Iterator:
"""Yield batches of positive and negative triples.""" | |||
while True: | |||
yield from loader | |||
def train(self) -> None: | |||
"""Run `n_steps` optimisation steps; return mean loss.""" | |||
data_batch = next(self.train_iterator) | |||
loss = self.model_trainer.train_step(data_batch, self.step) | |||
if self.step % self.log_console_every == 0: | |||
logger.info(f"[train] step {self.step:6d} | loss = {loss:.4f}") | |||
if self.step % self.log_every == 0: | |||
self.log("train/loss", loss) | |||
self.step += 1 | |||
    def evaluate(self) -> dict[str, float]:
        ks = [1, 3, 10]
        rank_metrics = [InverseHarmonicMeanRank()] + [HitsAtK(k) for k in ks]
        evaluator = RankBasedEvaluator(
            metrics=rank_metrics,
filtered=True, | |||
batch_size=self.training_params.validation_batch_size, | |||
) | |||
result = evaluator.evaluate( | |||
model=self.model_trainer.model, | |||
mapped_triples=self.val_dataset.triples, | |||
additional_filter_triples=[ | |||
self.train_dataset.triples, | |||
self.val_dataset.triples, | |||
], | |||
batch_size=self.training_params.validation_batch_size, | |||
) | |||
mrr = result.get_metric("both.realistic.inverse_harmonic_mean_rank") | |||
hits_at = {f"hits_at_{k}": result.get_metric(f"both.realistic.hits_at_{k}") for k in ks} | |||
metrics = { | |||
"mrr": mrr, | |||
**hits_at, | |||
} | |||
        for k, v in metrics.items():
            self.log(f"valid/{k}", v, console_msg=f"[valid] {k:6s} = {v:.4f}")
        return metrics
# ───────────────────────── # | |||
# master schedule | |||
# ───────────────────────── # | |||
def run(self) -> None: | |||
""" | |||
Alternate training and evaluation until `total_steps` | |||
optimisation steps have been executed. | |||
""" | |||
total_steps = self.training_params.num_train_steps | |||
eval_every = self.training_params.eval_every | |||
while self.step < total_steps: | |||
self.train() | |||
if self.step % eval_every == 0: | |||
logger.info(f"Evaluating at step {self.step}...") | |||
self.evaluate() | |||
if self.step % self.save_every == 0: | |||
logger.info(f"Saving model at step {self.step}...") | |||
self.save() | |||
logger.info("Training complete.") | |||
def reset(self) -> None: | |||
""" | |||
Reset the trainer state, including the model trainer and data iterators. | |||
""" | |||
self.model_trainer.reset() | |||
self.train_iterator = self._data_iterator(self.train_loader) | |||
self.valid_iterator = self._data_iterator(self.valid_loader) | |||
self.step = 0 | |||
logger.info("Trainer state has been reset.") |