Browse Source

Added code base.

main
Naser Kazemi 3 weeks ago
commit
02a34881d1
71 changed files with 4858 additions and 0 deletions
  1. 1
    0
      README.md
  2. 106
    0
      build_crec_datasets.py
  3. 20
    0
      configs/common/common.yaml
  4. 9
    0
      configs/config.yaml
  5. 3
    0
      configs/data/data.yaml
  6. 8
    0
      configs/data/fb15k.yaml
  7. 8
    0
      configs/data/wn18.yaml
  8. 8
    0
      configs/data/wn18rr.yaml
  9. 9
    0
      configs/data/yago3_10.yaml
  10. 3
    0
      configs/model/model.yaml
  11. 9
    0
      configs/model/trans_e.yaml
  12. 11
    0
      configs/model/trans_h.yaml
  13. 12
    0
      configs/model/trans_r.yaml
  14. 21
    0
      configs/training/training.yaml
  15. 11
    0
      configs/training/trans_e_trainer.yaml
  16. 11
    0
      configs/training/trans_r_trainer.yaml
  17. 17
    0
      configs/trans_e_fb15k.yaml
  18. 17
    0
      configs/trans_e_wn18.yaml
  19. 18
    0
      configs/trans_e_wn18rr.yaml
  20. 17
    0
      configs/trans_e_yago3_10.yaml
  21. 17
    0
      configs/trans_r.yaml
  22. 6
    0
      data/__init__.py
  23. 49
    0
      data/fb15k.py
  24. 57
    0
      data/hationet.py
  25. 89
    0
      data/kg_dataset.py
  26. 57
    0
      data/open_bio_link.py
  27. 57
    0
      data/openke_wiki.py
  28. 112
    0
      data/wn18.py
  29. 70
    0
      data/yago3_10.py
  30. 38
    0
      eval_datasets.py
  31. 80
    0
      main.py
  32. 0
    0
      metrics/__init__.py
  33. 55
    0
      metrics/base_metric.py
  34. 264
    0
      metrics/c_swklf.py
  35. 732
    0
      metrics/crec_modifier.py
  36. 106
    0
      metrics/crec_radius_sample.py
  37. 226
    0
      metrics/greedy_crec.py
  38. 30
    0
      metrics/ranking.py
  39. 237
    0
      metrics/wlcrec.py
  40. 135
    0
      metrics/wlec.py
  41. 0
    0
      models/__init__.py
  42. 127
    0
      models/base_model.py
  43. 0
    0
      models/translation/__init__.py
  44. BIN
      models/translation/__pycache__/__init__.cpython-310.pyc
  45. BIN
      models/translation/__pycache__/trans_e.cpython-310.pyc
  46. 70
    0
      models/translation/trans_e.py
  47. 76
    0
      models/translation/trans_h.py
  48. 79
    0
      models/translation/trans_r.py
  49. 82
    0
      pyproject.toml
  50. 28
    0
      setup/install.sh
  51. 163
    0
      setup/kg_env.yaml
  52. 6
    0
      tools/__init__.py
  53. 221
    0
      tools/checkpoint_manager.py
  54. 45
    0
      tools/params.py
  55. 417
    0
      tools/pretty_logger.py
  56. 66
    0
      tools/sampling.py
  57. 21
    0
      tools/tb_handler.py
  58. 12
    0
      tools/train.py
  59. 25
    0
      tools/utils.py
  60. 0
    0
      training/__init__.py
  61. 0
    0
      training/model_trainers/__init__.py
  62. BIN
      training/model_trainers/__pycache__/__init__.cpython-310.pyc
  63. BIN
      training/model_trainers/__pycache__/model_trainer_base.cpython-310.pyc
  64. 85
    0
      training/model_trainers/model_trainer_base.py
  65. 0
    0
      training/model_trainers/translation/__init__.py
  66. BIN
      training/model_trainers/translation/__pycache__/__init__.cpython-310.pyc
  67. BIN
      training/model_trainers/translation/__pycache__/trans_e_trainer.cpython-310.pyc
  68. 148
    0
      training/model_trainers/translation/trans_e_trainer.py
  69. 121
    0
      training/model_trainers/translation/trans_h_trainer.py
  70. 116
    0
      training/model_trainers/translation/trans_r_trainer.py
  71. 214
    0
      training/trainer.py

+ 1
- 0
README.md View File

@@ -0,0 +1 @@
# KGEvaluation

+ 106
- 0
build_crec_datasets.py View File

@@ -0,0 +1,106 @@
from __future__ import annotations

import argparse
import logging

import torch

from data import YAGO310Dataset

# from metrics.crec_modifier import tune_crec_parallel_edits
from data.kg_dataset import KGDataset
from metrics.crec_radius_sample import find_sample
from metrics.wlcrec import WLCREC
from tools import get_pretty_logger

# Named `logger` (not `logging`) so the stdlib `logging` module imported above
# is not shadowed by the project logger.
logger = get_pretty_logger(__name__)


# --------------------------------------------------------------------
# CLI
# --------------------------------------------------------------------
def main():
    """Sample a tuned triple set from YAGO3-10, report its WLCREC metrics, and save it.

    Command-line options:
        --depth  Depth passed to ``find_sample`` and to ``WLCREC.compute``.
        --lower  Lower bound forwarded to ``find_sample``.
        --upper  Upper bound forwarded to ``find_sample``.
        --seed   Seed forwarded to ``find_sample``. The default (58) is the value
                 that used to be hard-coded while this flag was silently ignored.
        --save   Output path for the tuned triples tensor.
                 NOTE(review): the default file name still says "fb15k" although
                 YAGO3-10 is loaded — confirm the intended name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--depth", type=int, default=5)
    parser.add_argument("--lower", type=float, default=0.25)
    parser.add_argument("--upper", type=float, default=0.26)
    parser.add_argument("--seed", type=int, default=58)
    parser.add_argument("--save", type=str, default="fb15k_crec_bi.pt")
    args = parser.parse_args()

    ds = YAGO310Dataset()
    triples = ds.triples
    n_ent, n_rel = ds.num_entities, ds.num_relations
    logger.info("YAGO3-10 loaded |V|=%d |R|=%d |T|=%d", n_ent, n_rel, len(triples))

    tuned = find_sample(
        triples.tolist(),
        n_ent,
        n_rel,
        depth=args.depth,
        ratio=0.1,  # Adjust ratio as needed
        radius=2,  # Adjust radius as needed
        lower=args.lower,
        upper=args.upper,
        max_tries=1000,
        seed=args.seed,
    )

    # Convert once; this tensor is what gets written to disk below.
    tuned_tensor = torch.tensor(tuned, dtype=torch.long)

    # KGDataset converts `triples` with torch.tensor(...) itself, so hand it the
    # raw result instead of an already-built tensor (avoids a second conversion
    # and the copy-construct warning).
    tuned_ds = KGDataset(
        triples_factory=None,
        triples=tuned,
        num_entities=n_ent,
        num_relations=n_rel,
        split="train",
    )

    # Evaluate at the same depth that was used for sampling instead of a literal 5.
    tuned_wlcrec = WLCREC(tuned_ds)
    entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = tuned_wlcrec.compute(args.depth)
    print(
        f"\nTuned CREC results (H={args.depth})",
        f"\n • avg WLEC : {entropy:.6f} nats",
        f"\n • C_ratio : {c_ratio:.6f}",
        f"\n • C_NWLEC : {c_nwlec:.6f}",
        f"\n • H_cond(R|S_H) : {h_cond:.6f} nats",
        f"\n • D_ratio : {d_ratio:.6f}",
        f"\n • D_NWLEC : {d_nwlec:.6f}",
        sep="",  # each part carries its own newline, as in eval_datasets.py
    )

    torch.save(tuned_tensor, args.save)
    logger.info("Saved tuned triples → %s", args.save)


if __name__ == "__main__":
    main()

+ 20
- 0
configs/common/common.yaml View File

@@ -0,0 +1,20 @@
_target_: tools.params.CommonParams

model_name: "default_model"
project_name: "my_project"
run_name: ...

# Where to write all outputs (checkpoints / tensorboard logs / etc.)
save_dpath: "./checkpoints"
save_every: 200

# If you want to resume from a checkpoint, set load_path to a string path;
# otherwise leave it null or empty.
load_path: null

# Default log directory
log_dir: "logs"
log_every: 10
log_console_every: 10

evaluate_only: false

+ 9
- 0
configs/config.yaml View File

@@ -0,0 +1,9 @@
defaults:
- common: common
- model: model
- data: data
- training: training

hydra:
run:
dir: "./outputs"

+ 3
- 0
configs/data/data.yaml View File

@@ -0,0 +1,3 @@
# WN18RRDataset is defined in data/wn18.py (see data/__init__.py); there is no
# data.wn18rr_dataset module, so the previous target could not be imported.
_target_: data.wn18.WN18RRDataset

split: train

+ 8
- 0
configs/data/fb15k.yaml View File

@@ -0,0 +1,8 @@
train:
_target_: data.fb15k.FB15KDataset
split: train


valid:
_target_: data.fb15k.FB15KDataset
split: valid

+ 8
- 0
configs/data/wn18.yaml View File

@@ -0,0 +1,8 @@
train:
_target_: data.wn18.WN18Dataset
split: train


valid:
_target_: data.wn18.WN18Dataset
split: valid

+ 8
- 0
configs/data/wn18rr.yaml View File

@@ -0,0 +1,8 @@
train:
_target_: data.wn18.WN18RRDataset
split: train


valid:
_target_: data.wn18.WN18RRDataset
split: valid

+ 9
- 0
configs/data/yago3_10.yaml View File

@@ -0,0 +1,9 @@
train:
_target_: data.yago3_10.YAGO310Dataset
split: train


valid:
_target_: data.yago3_10.YAGO310Dataset
split: valid


+ 3
- 0
configs/model/model.yaml View File

@@ -0,0 +1,3 @@
num_entities: 100
num_relations: 100
dim: 100

+ 9
- 0
configs/model/trans_e.yaml View File

@@ -0,0 +1,9 @@
defaults:
- model
- _self_

_target_: models.translation.trans_e.TransE

dim: 200
sf_norm: 1
p_norm: true

+ 11
- 0
configs/model/trans_h.yaml View File

@@ -0,0 +1,11 @@
defaults:
- model
- _self_

_target_: models.translation.trans_h.TransH

dim: 200
sf_norm: 1
p_norm: true
p_norm_value: 2
margin: 1.0

+ 12
- 0
configs/model/trans_r.yaml View File

@@ -0,0 +1,12 @@
defaults:
- model
- _self_

_target_: models.translation.trans_r.TransR

entity_dim: 200
relation_dim: 30
sf_norm: 1
p_norm: true
p_norm_value: 2
margin: 1.0

+ 21
- 0
configs/training/training.yaml View File

@@ -0,0 +1,21 @@
_target_: tools.params.TrainingParams

lr: 0.005
min_lr: 0.0001
weight_decay: 0.0
t0: 50
lr_step: 10
gamma: 0.95
batch_size: 8192
num_workers: 3
seed: 42

num_train_steps: 10000
eval_every: 100

validation_sample_size: 1000
validation_batch_size: 64


model_trainer:
_target_: training.model_trainers.model_trainer_base.ModelTrainerBase

+ 11
- 0
configs/training/trans_e_trainer.yaml View File

@@ -0,0 +1,11 @@
defaults:
- training
- _self_


model_trainer:
_target_: training.model_trainers.translation.trans_e_trainer.TransETrainer
margin: 1.0
n_negative: 5
regul_rate: 0.0
loss_fn: "margin"

+ 11
- 0
configs/training/trans_r_trainer.yaml View File

@@ -0,0 +1,11 @@
defaults:
- training
- _self_


model_trainer:
_target_: training.model_trainers.translation.trans_r_trainer.TransRTrainer
margin: 1.0
n_negative: 10
regul_rate: 0.0
loss_fn: "margin"

+ 17
- 0
configs/trans_e_fb15k.yaml View File

@@ -0,0 +1,17 @@
defaults:
- config
- override common: common
- override data: fb15k
- override model: trans_e
- override training: trans_e_trainer
- _self_


training:
model_trainer:
loss_fn: "margin"

common:
run_name: "TransE_FB15K"
evaluate_only: true
load_path: (2, 1800)

+ 17
- 0
configs/trans_e_wn18.yaml View File

@@ -0,0 +1,17 @@
defaults:
- config
- override common: common
- override data: wn18
- override model: trans_e
- override training: trans_e_trainer
- _self_


training:
model_trainer:
loss_fn: "margin"

common:
run_name: "TransE_WN18"
evaluate_only: true
load_path: (82, 3000)

+ 18
- 0
configs/trans_e_wn18rr.yaml View File

@@ -0,0 +1,18 @@
defaults:
- config
- override common: common
- override data: wn18rr
- override model: trans_e
- override training: trans_e_trainer
- _self_


training:
model_trainer:
loss_fn: "margin"


common:
run_name: "TransE_WN18rr"
evaluate_only: true
load_path: (2, 6000)

+ 17
- 0
configs/trans_e_yago3_10.yaml View File

@@ -0,0 +1,17 @@
defaults:
- config
- override common: common
- override data: yago3_10
- override model: trans_e
- override training: trans_e_trainer
- _self_


training:
model_trainer:
loss_fn: "margin"

common:
run_name: "TransE_YAGO310"
evaluate_only: true
load_path: (1, 2600)

+ 17
- 0
configs/trans_r.yaml View File

@@ -0,0 +1,17 @@
defaults:
- config
- override common: common
- override data: wn18
- override model: trans_r
- override training: trans_r_trainer
- _self_


training:
model_trainer:
loss_fn: "margin"

common:
run_name: "TransR_WN18"
# evaluate_only: true
# load_path: (1, 1000)

+ 6
- 0
data/__init__.py View File

@@ -0,0 +1,6 @@
from .fb15k import FB15KDataset
from .wn18 import WN18Dataset, WN18RRDataset
from .yago3_10 import YAGO310Dataset
from .hationet import HetionetDataset
from .open_bio_link import OpenBioLinkDataset
from .openke_wiki import OpenKEWikiDataset

+ 49
- 0
data/fb15k.py View File

@@ -0,0 +1,49 @@
from typing import List, Tuple

import torch
from pykeen.datasets import FB15k

from .kg_dataset import KGDataset


class FB15KDataset(KGDataset):
    """A wrapper around PyKeen's FB15k dataset producing KGDataset instances.

    By default the triples of the selected split are replaced by a custom triple
    tensor loaded from ``custom_triples_path``; pass ``None`` to keep the
    triples that ship with the split.
    """

    # Machine-specific location that used to be hard-coded in ``__init__``; kept
    # as the default so existing callers and configs behave as before.
    DEFAULT_CUSTOM_TRIPLES_PATH = (
        "/home/n_kazemi/projects/KGEvaluation/CustomDataset/fb15k/fb15k_crec_bi_1.pt"
    )

    def __init__(
        self,
        split: str = "train",
        custom_triples_path: "str | None" = DEFAULT_CUSTOM_TRIPLES_PATH,
    ) -> None:
        """
        Load FB15k and select the requested split.

        :param split: "train", "valid"/"val"/"validation", or "test"/"testing".
        :param custom_triples_path: Path to a ``torch.save``-d triples tensor that
            replaces the split's triples, or ``None`` to skip the replacement.
        :raises ValueError: If ``split`` is not a recognized split name.
        """
        dataset = FB15k()

        # The split dispatch used to be short-circuited by `if True:`, which made
        # the other branches and the ValueError unreachable.
        if split == "train":
            tf = dataset.training
        elif split in ("valid", "val", "validation"):
            tf = dataset.validation
        elif split in ("test", "testing"):
            tf = dataset.testing
        else:
            msg = f"Unrecognized split '{split}'"
            raise ValueError(msg)

        if custom_triples_path is not None:
            custom_triples = torch.load(
                custom_triples_path,
                map_location=torch.device("cpu"),
            ).to(torch.long)
            tf = tf.clone_and_exchange_triples(custom_triples)

        # One bulk `.tolist()` instead of iterating the tensor row by row and
        # calling int() on every element.
        triples_list: List[Tuple[int, int, int]] = [
            tuple(row) for row in tf.mapped_triples.tolist()
        ]

        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=dataset.num_entities,
            num_relations=dataset.num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the triple at ``idx`` as a fresh ``torch.long`` tensor."""
        row = super().__getitem__(idx)
        # The base class returns a tensor row; torch.tensor(<tensor>) would emit a
        # copy-construct UserWarning on every access. as_tensor + clone keeps the
        # "returns a copy" behaviour without the warning.
        return torch.as_tensor(row, dtype=torch.long).clone()

+ 57
- 0
data/hationet.py View File

@@ -0,0 +1,57 @@
from typing import List, Tuple

import torch
from pykeen.datasets import Hetionet

from .kg_dataset import KGDataset # Importing KGDataset from the same module


class HetionetDataset(KGDataset):
    """A wrapper around PyKeen's Hetionet dataset that produces KGDataset instances."""

    def __init__(self, split: str = "train") -> None:
        """
        Initialize the Hetionet wrapper.

        :param split: One of "train", "valid", or "test" indicating which split to load
            ("val"/"validation" and "testing" are accepted as aliases).
        :raises ValueError: If ``split`` is not a recognized split name.
        """
        # Load the PyKeen Hetionet dataset
        dataset = Hetionet()

        # Select the appropriate TriplesFactory based on the requested split
        if split == "train":
            tf = dataset.training
        elif split in ("valid", "val", "validation"):
            tf = dataset.validation
        elif split in ("test", "testing"):
            tf = dataset.testing
        else:
            msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
            raise ValueError(msg)

        # `tf.mapped_triples` holds one (head, relation, tail) ID row per triple.
        # A single bulk `.tolist()` yields plain Python ints and avoids iterating
        # the array row by row with an int() call per element.
        triples_list: List[Tuple[int, int, int]] = [
            tuple(row) for row in tf.mapped_triples.tolist()
        ]

        # Initialize the base KGDataset with the selected triples and the
        # dataset-wide entity / relation counts.
        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=dataset.num_entities,
            num_relations=dataset.num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the triple at ``idx`` as a fresh ``torch.long`` tensor."""
        data = super().__getitem__(idx)

        # The base class returns a tensor row; torch.tensor(<tensor>) would emit a
        # copy-construct UserWarning on every access. as_tensor + clone keeps the
        # "returns a copy" behaviour without the warning.
        return torch.as_tensor(data, dtype=torch.long).clone()

+ 89
- 0
data/kg_dataset.py View File

@@ -0,0 +1,89 @@
"""PyTorch Dataset wrapping indexed triples (h, r, t)."""

from pathlib import Path
from typing import Dict, List, Tuple

import torch
from pykeen.typing import MappedTriples
from torch.utils.data import Dataset


class KGDataset(Dataset):
    """PyTorch Dataset wrapping indexed triples (h, r, t)."""

    def __init__(
        self,
        triples_factory: "MappedTriples | None",
        triples: "List[Tuple[int, int, int]] | torch.Tensor",
        num_entities: int,
        num_relations: int,
        split: str = "train",
    ) -> None:
        """
        Store the triples and the metadata describing them.

        :param triples_factory: Stored unchanged on ``self.triples_factory``; may be
            ``None``. NOTE(review): the annotation names ``MappedTriples``, but the
            dataset wrappers in this package pass the split object they obtained
            from PyKEEN — confirm the intended type.
        :param triples: (h, r, t) ID triples, either as a sequence of int triples or
            as a tensor. They are always copied into a new ``torch.long`` tensor.
        :param num_entities: Total number of entities.
        :param num_relations: Total number of relations.
        :param split: Name of the split these triples belong to.
        """
        super().__init__()
        self.triples_factory = triples_factory
        if isinstance(triples, torch.Tensor):
            # torch.tensor(<tensor>) emits a copy-construct UserWarning; take an
            # explicit detached copy instead so the caller's tensor is not aliased.
            self.triples = triples.detach().clone().to(torch.long)
        else:
            self.triples = torch.tensor(triples, dtype=torch.long)
        self.num_entities = num_entities
        self.num_relations = num_relations
        self.split = split

    def __len__(self) -> int:
        """Return the number of triples in the dataset."""
        return len(self.triples)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the triple at the given index as a row of ``self.triples``."""
        return self.triples[idx]

    def entities(self) -> torch.Tensor:
        """Return a tensor with the IDs ``0 .. num_entities - 1``."""
        return torch.arange(self.num_entities)

    def relations(self) -> torch.Tensor:
        """Return a tensor with the IDs ``0 .. num_relations - 1``."""
        return torch.arange(self.num_relations)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(split={self.split!r}, "
            f"num_triples={len(self.triples)}, "
            f"num_entities={self.num_entities}, "
            f"num_relations={self.num_relations})"
        )


# ---------------------------- Helper functions --------------------------------


def _map_to_ids(
    triples: List[Tuple[str, str, str]],
) -> Tuple[List[Tuple[int, int, int]], Dict[str, int], Dict[str, int]]:
    """Replace the string labels in ``triples`` by consecutive integer IDs.

    Entities (heads and tails share one table) and relations are numbered
    independently, each in order of first appearance, starting at 0.

    Returns the encoded triples together with the entity and relation tables.
    """
    entity_ids: Dict[str, int] = {}
    relation_ids: Dict[str, int] = {}

    def _lookup(table: Dict[str, int], label: str) -> int:
        # An unseen label receives the next free ID, i.e. the current table size.
        if label not in table:
            table[label] = len(table)
        return table[label]

    # Tuple elements are evaluated left to right, so IDs are handed out in
    # head -> relation -> tail order for every triple.
    encoded: List[Tuple[int, int, int]] = [
        (
            _lookup(entity_ids, head),
            _lookup(relation_ids, relation),
            _lookup(entity_ids, tail),
        )
        for head, relation, tail in triples
    ]

    return encoded, entity_ids, relation_ids


def load_kg_from_files(
    root: str, splits: tuple[str, ...] = ("train", "valid", "test")
) -> Tuple[Dict[str, KGDataset], Dict[str, int], Dict[str, int]]:
    """Load tab-separated ``<split>.txt`` triple files from ``root``.

    All splits are mapped to IDs together, so they share one entity table and
    one relation table.

    :param root: Directory containing one ``<split>.txt`` file per split.
    :param splits: Names of the splits to load, in order.
    :return: ``(datasets, ent2id, rel2id)`` where ``datasets`` maps each split
        name to its ``KGDataset``.
    """
    root_dir = Path(root)
    split_raw, all_raw = {}, []
    for s in splits:
        text = (root_dir / f"{s}.txt").read_text()
        # Skip blank lines: they would become 1-tuples and break the
        # (h, r, t) unpacking in _map_to_ids.
        lines = [tuple(line.split("\t")) for line in text.splitlines() if line.strip()]
        split_raw[s] = lines
        all_raw.extend(lines)

    mapped, ent2id, rel2id = _map_to_ids(all_raw)

    datasets, start = {}, 0
    for s in splits:
        end = start + len(split_raw[s])
        # Keyword arguments are required here: the previous positional call put
        # the triples into `triples_factory`, the entity count into `triples`,
        # the relation count into `num_entities` and the split name into
        # `num_relations`.
        datasets[s] = KGDataset(
            triples_factory=None,
            triples=mapped[start:end],
            num_entities=len(ent2id),
            num_relations=len(rel2id),
            split=s,
        )
        start = end
    return datasets, ent2id, rel2id

+ 57
- 0
data/open_bio_link.py View File

@@ -0,0 +1,57 @@
from typing import List, Tuple

import torch
from pykeen.datasets import OpenBioLink

from .kg_dataset import KGDataset # Importing KGDataset from the same module


class OpenBioLinkDataset(KGDataset):
    """A wrapper around PyKeen's OpenBioLink dataset that produces KGDataset instances."""

    def __init__(self, split: str = "train") -> None:
        """
        Initialize the OpenBioLink wrapper.

        :param split: One of "train", "valid", or "test" indicating which split to load
            ("val"/"validation" and "testing" are accepted as aliases).
        :raises ValueError: If ``split`` is not a recognized split name.
        """
        # Load the PyKeen OpenBioLink dataset
        dataset = OpenBioLink()

        # Select the appropriate TriplesFactory based on the requested split
        if split == "train":
            tf = dataset.training
        elif split in ("valid", "val", "validation"):
            tf = dataset.validation
        elif split in ("test", "testing"):
            tf = dataset.testing
        else:
            msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
            raise ValueError(msg)

        # `tf.mapped_triples` holds one (head, relation, tail) ID row per triple.
        # A single bulk `.tolist()` yields plain Python ints and avoids iterating
        # the array row by row with an int() call per element.
        triples_list: List[Tuple[int, int, int]] = [
            tuple(row) for row in tf.mapped_triples.tolist()
        ]

        # Initialize the base KGDataset with the selected triples and the
        # dataset-wide entity / relation counts.
        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=dataset.num_entities,
            num_relations=dataset.num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the triple at ``idx`` as a fresh ``torch.long`` tensor."""
        data = super().__getitem__(idx)

        # The base class returns a tensor row; torch.tensor(<tensor>) would emit a
        # copy-construct UserWarning on every access. as_tensor + clone keeps the
        # "returns a copy" behaviour without the warning.
        return torch.as_tensor(data, dtype=torch.long).clone()

+ 57
- 0
data/openke_wiki.py View File

@@ -0,0 +1,57 @@
from typing import List, Tuple

import torch
from pykeen.datasets import Wikidata5M

from .kg_dataset import KGDataset # Importing KGDataset from the same module


class OpenKEWikiDataset(KGDataset):
    """A wrapper around PyKeen's Wikidata5M dataset that produces KGDataset instances.

    NOTE(review): the class name refers to OpenKE-Wiki, but the dataset that is
    actually loaded is ``pykeen.datasets.Wikidata5M`` — confirm this is intended.
    """

    def __init__(self, split: str = "train") -> None:
        """
        Initialize the Wikidata5M wrapper.

        :param split: One of "train", "valid", or "test" indicating which split to load
            ("val"/"validation" and "testing" are accepted as aliases).
        :raises ValueError: If ``split`` is not a recognized split name.
        """
        # Load the PyKeen Wikidata5M dataset
        dataset = Wikidata5M()

        # Select the appropriate TriplesFactory based on the requested split
        if split == "train":
            tf = dataset.training
        elif split in ("valid", "val", "validation"):
            tf = dataset.validation
        elif split in ("test", "testing"):
            tf = dataset.testing
        else:
            msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
            raise ValueError(msg)

        # `tf.mapped_triples` holds one (head, relation, tail) ID row per triple.
        # A single bulk `.tolist()` yields plain Python ints and avoids iterating
        # the array row by row with an int() call per element.
        triples_list: List[Tuple[int, int, int]] = [
            tuple(row) for row in tf.mapped_triples.tolist()
        ]

        # Initialize the base KGDataset with the selected triples and the
        # dataset-wide entity / relation counts.
        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=dataset.num_entities,
            num_relations=dataset.num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the triple at ``idx`` as a fresh ``torch.long`` tensor."""
        data = super().__getitem__(idx)

        # The base class returns a tensor row; torch.tensor(<tensor>) would emit a
        # copy-construct UserWarning on every access. as_tensor + clone keeps the
        # "returns a copy" behaviour without the warning.
        return torch.as_tensor(data, dtype=torch.long).clone()

+ 112
- 0
data/wn18.py View File

@@ -0,0 +1,112 @@
from typing import List, Tuple

import torch
from pykeen.datasets import WN18, WN18RR

from .kg_dataset import KGDataset # Importing KGDataset from the same module


class WN18Dataset(KGDataset):
    """A wrapper around PyKeen's WN18 dataset that produces KGDataset instances."""

    def __init__(self, split: str = "train") -> None:
        """
        Initialize the WN18 wrapper.

        :param split: One of "train", "valid", or "test" indicating which split to load
            ("val"/"validation" and "testing" are accepted as aliases).
        :raises ValueError: If ``split`` is not a recognized split name.
        """
        # Load the PyKeen WN18 dataset
        dataset = WN18()

        # Select the appropriate TriplesFactory based on the requested split
        if split == "train":
            tf = dataset.training
        elif split in ("valid", "val", "validation"):
            tf = dataset.validation
        elif split in ("test", "testing"):
            tf = dataset.testing
        else:
            msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
            raise ValueError(msg)

        # `tf.mapped_triples` holds one (head, relation, tail) ID row per triple.
        # A single bulk `.tolist()` yields plain Python ints and avoids iterating
        # the array row by row with an int() call per element.
        triples_list: List[Tuple[int, int, int]] = [
            tuple(row) for row in tf.mapped_triples.tolist()
        ]

        # Initialize the base KGDataset with the selected triples and the
        # dataset-wide entity / relation counts.
        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=dataset.num_entities,
            num_relations=dataset.num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the triple at ``idx`` as a fresh ``torch.long`` tensor."""
        data = super().__getitem__(idx)

        # The base class returns a tensor row; torch.tensor(<tensor>) would emit a
        # copy-construct UserWarning on every access. as_tensor + clone keeps the
        # "returns a copy" behaviour without the warning.
        return torch.as_tensor(data, dtype=torch.long).clone()


# KGDataset is imported from .kg_dataset at the top of this module.


class WN18RRDataset(KGDataset):
    """A wrapper around PyKeen's WN18RR dataset that produces KGDataset instances."""

    def __init__(self, split: str = "train") -> None:
        """
        Initialize the WN18RR wrapper.

        :param split: One of "train", "valid", or "test" indicating which split to load
            ("val"/"validation" and "testing" are accepted as aliases).
        :raises ValueError: If ``split`` is not a recognized split name.
        """
        # Load the PyKeen WN18RR dataset
        dataset = WN18RR()

        # Select the appropriate TriplesFactory based on the requested split
        if split == "train":
            tf = dataset.training
        elif split in ("valid", "val", "validation"):
            tf = dataset.validation
        elif split in ("test", "testing"):
            tf = dataset.testing
        else:
            msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
            raise ValueError(msg)

        # `tf.mapped_triples` holds one (head, relation, tail) ID row per triple.
        # A single bulk `.tolist()` yields plain Python ints and avoids iterating
        # the array row by row with an int() call per element.
        triples_list: List[Tuple[int, int, int]] = [
            tuple(row) for row in tf.mapped_triples.tolist()
        ]

        # Initialize the base KGDataset with the selected triples and the
        # dataset-wide entity / relation counts.
        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=dataset.num_entities,
            num_relations=dataset.num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the triple at ``idx`` as a fresh ``torch.long`` tensor."""
        data = super().__getitem__(idx)

        # The base class returns a tensor row; torch.tensor(<tensor>) would emit a
        # copy-construct UserWarning on every access. as_tensor + clone keeps the
        # "returns a copy" behaviour without the warning.
        return torch.as_tensor(data, dtype=torch.long).clone()

+ 70
- 0
data/yago3_10.py View File

@@ -0,0 +1,70 @@
from typing import List, Tuple

import torch
from pykeen.datasets import PathDataset

from .kg_dataset import KGDataset


class YAGO310Dataset(KGDataset):
    """Wrapper around a local copy of the YAGO3-10 dataset, loaded via PyKeen's PathDataset.

    By default the triples of the selected split are replaced by a custom triple
    tensor loaded from ``custom_triples_path``; pass ``None`` to keep the
    triples read from the split's text file.
    """

    # Machine-specific locations that used to be hard-coded in ``__init__``; kept
    # as defaults so existing callers and configs behave as before.
    DEFAULT_DATA_ROOT = "/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10"
    DEFAULT_CUSTOM_TRIPLES_PATH = (
        "/home/n_kazemi/projects/KGEvaluation/CustomDataset/yago310/yago310_crec_bi_1.pt"
    )

    def __init__(
        self,
        split: str = "train",
        data_root: str = DEFAULT_DATA_ROOT,
        custom_triples_path: "str | None" = DEFAULT_CUSTOM_TRIPLES_PATH,
    ) -> None:
        """
        Load YAGO3-10 from ``data_root`` and select the requested split.

        :param split: "train", "valid"/"val"/"validation", or "test"/"testing".
        :param data_root: Directory containing ``train.txt``, ``valid.txt`` and
            ``test.txt``.
        :param custom_triples_path: Path to a ``torch.save``-d triples tensor that
            replaces the split's triples, or ``None`` to skip the replacement.
        :raises ValueError: If ``split`` is not a recognized split name.
        """
        dataset = PathDataset(
            training_path=f"{data_root}/train.txt",
            testing_path=f"{data_root}/test.txt",
            validation_path=f"{data_root}/valid.txt",
        )

        if split == "train":
            tf = dataset.training
        elif split in ("valid", "val", "validation"):
            tf = dataset.validation
        elif split in ("test", "testing"):
            tf = dataset.testing
        else:
            msg = f"Unrecognized split '{split}'"
            raise ValueError(msg)

        if custom_triples_path is not None:
            custom_triples = torch.load(
                custom_triples_path,
                map_location=torch.device("cpu"),
            ).to(torch.long)
            tf = tf.clone_and_exchange_triples(custom_triples)

        # One bulk `.tolist()` instead of iterating the tensor row by row and
        # calling int() on every element.
        triples_list: List[Tuple[int, int, int]] = [
            tuple(row) for row in tf.mapped_triples.tolist()
        ]

        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=dataset.num_entities,
            num_relations=dataset.num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the triple at ``idx`` as a fresh ``torch.long`` tensor."""
        row = super().__getitem__(idx)
        # The base class returns a tensor row; torch.tensor(<tensor>) would emit a
        # copy-construct UserWarning on every access. as_tensor + clone keeps the
        # "returns a copy" behaviour without the warning.
        return torch.as_tensor(row, dtype=torch.long).clone()

+ 38
- 0
eval_datasets.py View File

@@ -0,0 +1,38 @@
from data import (
FB15KDataset,
WN18Dataset,
WN18RRDataset,
YAGO310Dataset,
)
from metrics.wlcrec import WLCREC


def main() -> None:
    """Compute the WLCREC statistics for each configured dataset and print a report.

    Only the YAGO3-10 training split is evaluated at the moment; further
    datasets can be added to the mapping below.
    """
    corpora = {
        "YAGO3-10": YAGO310Dataset(split="train"),
    }

    for name, kg in corpora.items():
        metric = WLCREC(kg)
        entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = metric.compute(H=20)

        # Each fragment starts with its own newline, so they are emitted
        # back to back without a separator.
        report = [
            f"\nDataset: {name}",
            f"\nResults (H={20})",
            f"\n • avg WLEC : {entropy:.6f} nats",
            f"\n • C_ratio : {c_ratio:.6f}",
            f"\n • C_NWLEC : {c_nwlec:.6f}",
            f"\n • H_cond(R|S_H) : {h_cond:.6f} nats",
            f"\n • D_ratio : {d_ratio:.6f}",
            f"\n • D_NWLEC : {d_nwlec:.6f}",
        ]
        print(*report, sep="")


if __name__ == "__main__":
main()

+ 80
- 0
main.py View File

@@ -0,0 +1,80 @@
from copy import deepcopy

import hydra
import lovely_tensors as lt
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig

from metrics.c_swklf import CSWKLF
from tools import get_pretty_logger

# (Import whatever Trainer or runner you have)
from training.trainer import Trainer

logger = get_pretty_logger(__name__)

lt.monkey_patch()


@hydra.main(config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    """Build datasets, model and trainer from the Hydra config, then train or evaluate.

    When ``cfg.common.evaluate_only`` is set, the trainer's evaluation is run and
    the CSWKLF metric is computed on the validation dataset and logged;
    otherwise the training loop is started.
    """
    # Use the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The training parameters are instantiated without the nested model_trainer
    # node, which is instantiated on its own further down.
    training_cfg = deepcopy(cfg.training)
    del training_cfg.model_trainer

    common_params = instantiate(cfg.common)
    training_params = instantiate(training_cfg)

    # Datasets
    train_dataset = instantiate(cfg.data.train)
    val_dataset = instantiate(cfg.data.valid)
    triples_factory = train_dataset.triples_factory

    # Model, sized from the training dataset
    model = instantiate(
        cfg.model,
        num_entities=train_dataset.num_entities,
        num_relations=train_dataset.num_relations,
        triples_factory=triples_factory,
        device=device,
    ).to(device)

    model_trainer = instantiate(
        cfg.training.model_trainer,
        model=model,
        training_params=training_params,
        triples_factory=triples_factory,
        mapped_triples=train_dataset.triples,
        device=model.device,
    )

    trainer = Trainer(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        model_trainer=model_trainer,
        common_params=common_params,
        training_params=training_params,
        device=model.device,
    )

    if not cfg.common.evaluate_only:
        trainer.run()
        return

    trainer.evaluate()
    c_swklf_value = CSWKLF(dataset=val_dataset, model=model).compute()
    logger.info(f"CSWKLF Metric Value: {c_swklf_value}")


if __name__ == "__main__":
main()

+ 0
- 0
metrics/__init__.py View File


+ 55
- 0
metrics/base_metric.py View File

@@ -0,0 +1,55 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Tuple

if TYPE_CHECKING:
from data.kg_dataset import KGDataset


class BaseMetric(ABC):
    """Common base for graph metrics such as Weisfeiler-Lehman Entropy
    Complexity (WLEC).

    On construction the dataset's triples are indexed into per-entity
    outgoing and incoming adjacency lists (``out_adj`` / ``in_adj``).
    """

    def __init__(self, dataset: KGDataset) -> None:
        self.dataset = dataset
        self.num_entities = dataset.num_entities

        adjacency = self._build_adjacency_lists()
        self.out_adj, self.in_adj = adjacency

    def _build_adjacency_lists(
        self,
    ) -> Tuple[List[List[Tuple[int, int]]], List[List[Tuple[int, int]]]]:
        """Index ``self.dataset.triples`` by head and by tail entity.

        Takes no arguments; it reads ``self.dataset.triples`` (rows of
        *(h, r, t)*) and ``self.num_entities``.

        Returns
        -------
        out_adj, in_adj
            ``out_adj[v]`` lists ``(r, t)`` for every triple whose head is ``v``;
            ``in_adj[v]`` lists ``(r, h)`` for every triple whose tail is ``v``.
            Both have one (possibly empty) list per entity.
        """
        entity_count = self.num_entities
        outgoing: List[List[Tuple[int, int]]] = [[] for _ in range(entity_count)]
        incoming: List[List[Tuple[int, int]]] = [[] for _ in range(entity_count)]

        for head, relation, tail in self.dataset.triples.tolist():
            outgoing[head].append((relation, tail))
            incoming[tail].append((relation, head))

        return outgoing, incoming

    @abstractmethod
    def compute(
        self,
        *args,
        **kwargs,
    ) -> Any: ...

+ 264
- 0
metrics/c_swklf.py View File

@@ -0,0 +1,264 @@
from __future__ import annotations

import math
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Iterable

import torch
from torch import Tensor

from .base_metric import BaseMetric

if TYPE_CHECKING:
from data.kg_dataset import KGDataset
from models.base_model import ERModel


@torch.no_grad()
def _log_prob_distribution(
    fn: Callable[[Tensor], Tensor],
    query: Tensor,
    batch_size: int,
) -> Iterable[Tensor]:
    """
    Apply ``fn`` to ``query`` in consecutive mini-batches, yielding each result.

    Parameters
    ----------
    fn :
        Callable invoked once per slice of ``query``; intended to be a model
        distribution helper returning **log**-probabilities.
    query :
        The queries to score; sliced along its first dimension.
    batch_size :
        Maximum number of rows handed to ``fn`` per call. Choose according to
        available GPU/CPU memory.
    """
    batch_starts = range(0, len(query), batch_size)
    for lo in batch_starts:
        hi = lo + batch_size
        chunk = query[lo:hi]
        yield fn(chunk)  # one result per mini-batch


def _kl_uniform(log_p: Tensor, true_index_sets: list[list[int]]) -> Tensor:
    """
    Row-wise KL(q || p) with q uniform over the correct indices of each row.

    Parameters
    ----------
    log_p :
        Model log-probabilities, one row per query.
    true_index_sets :
        One index list per row of ``log_p``; entry *b* names the columns that
        are correct for row *b*.

    Returns
    -------
    kl : Tensor with one KL value per row (natural log), on ``log_p``'s device.
    """
    # For a uniform q over k indices: KL = log k - mean(log p over those indices).
    per_row = [
        math.log(len(gold)) - row_log_p[gold].mean()
        for row_log_p, gold in zip(log_p, true_index_sets)
    ]
    return torch.tensor(per_row, device=log_p.device)


@dataclass(slots=True)
class _Accum:
"""Simple accumulator for ∑ value and ∑ weight."""

tot_v: float = 0.0
tot_w: float = 0.0

def update(self, value: float, weight: float) -> None:
self.tot_v += value * weight
self.tot_w += weight

@property
def mean(self) -> float:
return self.tot_v / self.tot_w if self.tot_w else float("nan")


class CSWKLF(BaseMetric):
    """
    C-SWKLF (Comprehensive Structural-Weighted KL-Fitness) metric.

    Compares the model's tail, head and relation distributions against the
    answer sets observed in the dataset's triples and blends the three
    directions into a single score.
    """

    def __init__(self, dataset: KGDataset, model: ERModel) -> None:
        """
        Parameters
        ----------
        dataset :
            Dataset whose ``triples`` tensor is evaluated (forwarded to ``BaseMetric``).
        model :
            Model exposing ``tail_distribution``, ``head_distribution`` and
            ``relation_distribution``.
        """
        super().__init__(dataset)
        self.model = model

    @torch.no_grad()
    def compute(
        self,
        alpha: float = 1 / 3,
        beta: float = 1 / 3,
        gamma: float = 1 / 3,
        batch_size: int = 1024,
        slice_size: int | None = 2048,
        device: torch.device | str | None = None,
    ) -> float:
        """
        Compute *Comprehensive Structural-Weighted KL-Fitness*.

        The model is ``self.model`` and the triples come from
        ``self.dataset.triples`` (presumably stored by ``BaseMetric``).

        Parameters
        ----------
        alpha, beta, gamma :
            Non-negative weights for the tail, head and relation query types,
            default 1/3 each. Must sum to *1*.
        batch_size :
            Number of queries scored per forward pass. Tune wrt. GPU/CPU RAM.
        slice_size :
            Forwarded to `tail_distribution` / `head_distribution` /
            `relation_distribution` as ``slice_size``.
        device :
            Device the triples and query tensors are moved to.
            `None` ⇒ device of the model's first parameter.

        Returns
        -------
        score : float
            ``alpha * S_tail + beta * S_head + gamma * S_rel``, where each term
            is built from ``exp(-KL)`` values. Higher = model closer to the
            empirical distributions across *all* directions. The result is
            ``nan`` when a direction has zero total structural weight (e.g.
            every (h, r) query has exactly one tail) or when there are no
            triples at all.
        """
        # --------------------------------------------------------------------- #
        # Preparation                                                           #
        # --------------------------------------------------------------------- #
        assert abs(alpha + beta + gamma - 1.0) < 1e-9, "α+β+γ must be 1."

        if device is None:
            device = next(self.model.parameters()).device
        triples = self.dataset.triples.to(device)  # (|T|, 3)

        heads, rels, tails = triples.t()  # each (|T|,)

        # Build index structures: query key -> list of observed answers.
        tails_by_hr: dict[tuple[int, int], list[int]] = defaultdict(list)
        heads_by_rt: dict[tuple[int, int], list[int]] = defaultdict(list)
        rels_by_ht: dict[tuple[int, int], list[int]] = defaultdict(list)

        for h, r, t in zip(heads.tolist(), rels.tolist(), tails.tolist()):
            tails_by_hr[(h, r)].append(t)
            heads_by_rt[(r, t)].append(h)
            rels_by_ht[(h, t)].append(r)

        # Structural entropies: mean log(answer-set size) per relation, used
        # below as the per-relation weights of the tail/head scores.
        H_tail: dict[int, _Accum] = defaultdict(_Accum)
        H_head: dict[int, _Accum] = defaultdict(_Accum)
        for (_h, r), ts in tails_by_hr.items():
            H_tail[r].update(math.log(len(ts)), 1.0)
        for (r, _t), hs in heads_by_rt.items():
            H_head[r].update(math.log(len(hs)), 1.0)

        H_struct_tail = {r: acc.mean for r, acc in H_tail.items()}
        H_struct_head = {r: acc.mean for r, acc in H_head.items()}

        # Query lists keep the dicts' insertion order.
        tail_queries: list[tuple[int, int]] = list(tails_by_hr)  # (h, r)
        head_queries: list[tuple[int, int]] = list(heads_by_rt)  # (r, t)
        rel_queries: list[tuple[int, int]] = list(rels_by_ht)  # (h, t)

        # --------------------------------------------------------------------- #
        # KL divergences                                                        #
        # --------------------------------------------------------------------- #
        def _batched_kl(
            queries: list[tuple[int, int]],
            true_idx_map: dict[tuple[int, int], list[int]],
            distribution_fn: Callable[[Tensor], Tensor],
        ) -> Iterable[tuple[list[tuple[int, int]], Tensor]]:
            """Yield ``(batch_queries, kl_per_query)`` for consecutive query batches."""
            query_tensor = torch.tensor(queries, dtype=torch.long, device=device)
            offset = 0
            for log_p in _log_prob_distribution(distribution_fn, query_tensor, batch_size):
                # log_p : (B, |Candidates|); walk the query list by offset
                # instead of repeatedly re-slicing (and mutating) it.
                b = log_p.shape[0]
                batch_queries = queries[offset : offset + b]
                offset += b
                true_sets = [true_idx_map[q] for q in batch_queries]
                yield batch_queries, _kl_uniform(log_p, true_sets)

        def _avg_kl_per_relation(
            queries: list[tuple[int, int]],
            true_idx_map: dict[tuple[int, int], list[int]],
            distribution_fn: Callable[[Tensor], Tensor],
            relation_first: bool = False,
        ) -> dict[int, float]:
            """
            Compute the *average* KL(q || p) **per relation**.

            ``relation_first`` says where the relation id sits in each query
            tuple: second for tail queries (h, r), first for head queries (r, t).
            """
            kl_sum: dict[int, float] = defaultdict(float)
            count: dict[int, int] = defaultdict(int)
            for batch_queries, kl_row in _batched_kl(queries, true_idx_map, distribution_fn):
                for (first, second), kl_val in zip(batch_queries, kl_row.tolist()):
                    rel_id = first if relation_first else second
                    kl_sum[rel_id] += kl_val
                    count[rel_id] += 1
            return {r: kl_sum[r] / count[r] for r in kl_sum}

        # --- tails --------------------------------------------------------------
        D_tail = _avg_kl_per_relation(
            tail_queries,
            tails_by_hr,
            lambda q, s=slice_size: self.model.tail_distribution(q, log=True, slice_size=s),
        )
        S_tail = {r: math.exp(-d) for r, d in D_tail.items()}

        # --- heads --------------------------------------------------------------
        D_head = _avg_kl_per_relation(
            head_queries,
            heads_by_rt,
            lambda q, s=slice_size: self.model.head_distribution(q, log=True, slice_size=s),
            relation_first=True,
        )
        S_head = {r: math.exp(-d) for r, d in D_head.items()}

        # --- relations (single global number) ----------------------------------
        D_rel_sum = 0.0
        for _batch_queries, kl_row in _batched_kl(
            rel_queries,
            rels_by_ht,
            lambda q, s=slice_size: self.model.relation_distribution(q, log=True, slice_size=s),
        ):
            D_rel_sum += kl_row.sum().item()

        # No (h, t) pairs means there is nothing to average: report nan rather
        # than dividing by zero.
        D_rel = D_rel_sum / len(rel_queries) if rel_queries else float("nan")
        S_rel = math.exp(-D_rel)

        # --------------------------------------------------------------------- #
        # Weighted aggregation → C-SWKLF                                        #
        # --------------------------------------------------------------------- #
        def _weighted_mean(score: dict[int, float], weight: dict[int, float]) -> float:
            """Mean of ``score`` weighted by ``weight``; ``nan`` if the weights sum to 0."""
            num = sum(weight[r] * score[r] for r in score)
            den = sum(weight[r] for r in score)
            return num / den if den else float("nan")

        sw_tail = _weighted_mean(S_tail, H_struct_tail)
        sw_head = _weighted_mean(S_head, H_struct_head)
        # Relations: a single global weight, so numerator & denominator cancel.
        sw_rel = S_rel

        return alpha * sw_tail + beta * sw_head + gamma * sw_rel

+ 732
- 0
metrics/crec_modifier.py View File

@@ -0,0 +1,732 @@
# from __future__ import annotations

# import math
# import os
# import random
# import threading
# from collections import defaultdict
# from concurrent.futures import ThreadPoolExecutor, as_completed
# from typing import Dict, List, Tuple

# import torch

# from data.kg_dataset import KGDataset
# from metrics.wlcrec import WLCREC # Assuming WLCREC is defined in metrics/wlcrec.py
# from tools import get_pretty_logger

# logger = get_pretty_logger(__name__)


# # --------------------------------------------------------------------
# # Edge-editing primitives
# # --------------------------------------------------------------------
# # -- additions --------------------
# def add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng):
# for _ in range(1000):
# h = rng.randrange(n_ent)
# t = rng.randrange(n_ent)
# if colours[h] != colours[t]:
# r = rng.randrange(n_rel)
# triples.append((h, r, t))
# out_adj[h].append((r, t))
# in_adj[t].append((r, h))
# return


# def add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng):
# σ = rng.choice(colours)
# τ = rng.choice(colours)
# h = colours.index(σ)
# t = colours.index(τ)
# triples.append((h, 0, t))
# out_adj[h].append((0, t))
# in_adj[t].append((0, h))


# # -- removals --------------------


# def rem_div_edge(triples, out_adj, in_adj, colours, depth, rng):
# cand = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]]
# if cand:
# idx = rng.choice(cand)
# h, r, t = triples.pop(idx)
# out_adj[h].remove((r, t))
# in_adj[t].remove((r, h))


# def rem_det_edge(triples, out_adj, in_adj, colours, depth, rng):
# sig_rel = defaultdict(set)
# for h, r, t in triples:
# σ, τ = colours[h], colours[t]
# sig_rel[(σ, τ)].add(r)
# cand = []
# for i, (h, r, t) in enumerate(triples):
# if len(sig_rel[(colours[h], colours[t])]) == 1:
# cand.append(i)
# if cand:
# idx = rng.choice(cand)
# h, r, t = triples.pop(idx)
# out_adj[h].remove((r, t))
# in_adj[t].remove((r, h))


# def _search_worker(
# seed: int,
# triples_init: List[Tuple[int, int, int]],
# triples_factory,
# n_ent: int,
# n_rel: int,
# depth: int,
# lo: float,
# hi: float,
# max_iters: int,
# ) -> List[Tuple[int, int, int]] | None:
# """
# Run the exact same hill‑climb that `tune_crec()` did,
# but entirely in this process. Return the edited triples
# once c falls in [lo, hi]; return None if we used up all iterations.
# """
# rng = random.Random(seed)
# triples = triples_init.copy().tolist()

# for it in range(max_iters):
# # WL‑CREC is *only* recomputed every 1000 edits, exactly like before
# if it % 1000 == 0:
# dataset = KGDataset(
# triples_factory,
# triples=torch.tensor(triples, dtype=torch.long),
# num_entities=n_ent,
# num_relations=n_rel,
# )
# wl = WLCREC(dataset)
# colours = wl.wl_colours(depth)
# *_, c, _ = wl.compute(H=5) # unchanged API

# if lo <= c <= hi: # success
# logger.info("[seed %d] hit %.4f in %d edits", seed, c, it)
# return triples

# # ---------- identical edit logic ----------
# if c < lo:
# if rng.random() < 0.5:
# rem_det_edge(triples, colours, depth, rng)
# else:
# add_div_edge(triples, colours, depth, n_ent, n_rel, rng)
# else: # c > hi
# if rng.random() < 0.5:
# rem_div_edge(triples, colours, depth, rng)
# else:
# add_det_edge(triples, colours, depth, n_rel, rng)
# # ------------------------------------------

# return None # used up our budget


# # --------------------------------------------------------------------
# # Unified tuner
# # --------------------------------------------------------------------
# def tune_crec(
# wl_crec: WLCREC,
# n_ent: int,
# n_rel: int,
# depth: int,
# lo: float,
# hi: float,
# max_iters: int = 80_000,
# seed: int = 42,
# ):
# triples = wl_crec.dataset.triples.tolist()

# rng = random.Random(seed)
# for it in range(max_iters):
# print(f"\r[iter {it + 1:5d}] ", end="")
# if it % 1000 == 0:
# dataset = KGDataset(
# wl_crec.dataset.triples_factory,
# triples=torch.tensor(triples, dtype=torch.long),
# num_entities=n_ent,
# num_relations=n_rel,
# )
# tmp_wl_crec = WLCREC(dataset)

# # colours = wl_colours(triples, n_ent, depth)
# colours = tmp_wl_crec.wl_colours(depth)
# _, _, _, _, c, _ = tmp_wl_crec.compute(H=5)
# if lo <= c <= hi:
# logging.info("WL-CREC %.4f reached after %d edits (|T|=%d)", c, it, len(triples))
# return triples
# if c < lo:
# # need ↑ WL-CREC → prefer deletion of deterministic, else add diversifying
# if rng.random() < 0.5:
# rem_det_edge(triples, colours, depth, rng)
# else:
# add_div_edge(triples, colours, depth, n_ent, n_rel, rng)
# # need ↓ WL-CREC → prefer deletion of diversifying, else add deterministic
# elif rng.random() < 0.5:
# rem_div_edge(triples, colours, depth, rng)
# else:
# add_det_edge(triples, colours, depth, n_rel, rng)
# if (it + 1) % 10000 == 1:
# logging.info("[iter %d] WL-CREC %.4f |T|=%d", it + 1, c, len(triples))
# raise RuntimeError("Exceeded max iterations without hitting target band.")


# def _edit_batch(
# worker_id: int,
# n_edits: int,
# colours: List[int],
# crec: float,
# target_lo: float,
# target_hi: float,
# depth: int,
# n_ent: int,
# n_rel: int,
# seed_base: int,
# ):
# """
# Perform `n_edits` topology modifications *locally* and return
# (added_triples, removed_triples) lists.
# """
# rng = random.Random(seed_base + worker_id)
# local_added, local_removed = [], []

# # local graph views: only sizes matter for edit selection
# for _ in range(n_edits):
# if crec < target_lo: # need ↑ WL-CREC
# if rng.random() < 0.5: # • remove deterministic
# # We cannot remove by index safely without the global list;
# # choose a *signature* and remember intention to delete:
# local_removed.append(("det", rng.random()))
# else: # • add diversifying
# h = rng.randrange(n_ent)
# t = rng.randrange(n_ent)
# while colours[h] == colours[t]:
# h = rng.randrange(n_ent)
# t = rng.randrange(n_ent)
# r = rng.randrange(n_rel)
# local_added.append((h, r, t))
# else: # need ↓ WL-CREC
# if rng.random() < 0.5: # • remove diversifying
# local_removed.append(("div", rng.random()))
# else: # • add deterministic
# σ = rng.choice(colours)
# τ = rng.choice(colours)
# h = colours.index(σ)
# t = colours.index(τ)
# local_added.append((h, 0, t))

# return local_added, local_removed


# def wl_colours(triples, n_ent, depth):
# # build adjacency once
# out_adj, in_adj = defaultdict(list), defaultdict(list)
# for h, r, t in triples:
# out_adj[h].append((r, t))
# in_adj[t].append((r, h))

# colours_rounds = [[0] * n_ent] # round-0 colours
# for h in range(1, depth + 1):
# prev = colours_rounds[-1]

# # 1) build textual signatures in parallel
# def sig(v):
# neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [
# ("↑", r, prev[u]) for r, u in in_adj.get(v, [])
# ]
# neigh.sort()
# return (prev[v], tuple(neigh))

# with ThreadPoolExecutor() as tpe: # cheap threads inside worker
# sigs = list(tpe.map(sig, range(n_ent)))

# # 2) assign deterministic colour IDs
# sig2id: Dict[Tuple, int] = {}
# next_round = [0] * n_ent
# fresh = 0
# for v, sg in enumerate(sigs):
# cid = sig2id.setdefault(sg, fresh)
# if cid == fresh:
# fresh += 1
# next_round[v] = cid
# colours_rounds.append(next_round)

# depth_colours = colours_rounds[-1]

# return depth_colours


# def _metric_worker(args):
# triples, n_ent, n_rel, depth = args
# dataset = KGDataset(
# triples_factory=None, # Not used in this context
# triples=torch.tensor(triples, dtype=torch.long),
# num_entities=n_ent,
# num_relations=n_rel,
# )
# wl_crec = WLCREC(dataset)
# _, _, _, _, c, _ = wl_crec.compute(H=5, return_full=False)
# colours = wl_colours(triples, n_ent, depth)
# return c, colours


# def tune_crec_parallel_edits(
# triples_init,
# n_ent: int,
# n_rel: int,
# depth: int,
# target_lo: float,
# target_hi: float,
# max_iters: int = 80_000,
# metric_every: int = 100,
# n_workers: int = max(20, math.ceil(os.cpu_count() / 2)),
# seed: int = 42,
# ):
# # -------- shared mutable state (main thread owns it) --------
# triples = triples_init.tolist()
# out_adj, in_adj = defaultdict(list), defaultdict(list)
# for h, r, t in triples:
# out_adj[h].append((r, t))
# in_adj[t].append((r, h))

# pool = ThreadPoolExecutor(max_workers=n_workers)
# metric_lock = threading.Lock() # exactly one metric at a time
# rng_global = random.Random(seed)

# # ----- first metric checkpoint -----
# crec, colours = _metric_worker((triples, n_ent, n_rel, depth))

# edit_budget_total = 0
# for it in range(0, max_iters, metric_every):
# # =========================================================
# # 1. PARALLEL EDIT STAGE (metric_every edits in total)
# # =========================================================
# futures = []
# edits_per_worker = metric_every // n_workers
# extra = metric_every % n_workers
# for wid in range(n_workers):
# n_edits = edits_per_worker + (1 if wid < extra else 0)
# futures.append(
# pool.submit(
# _edit_batch,
# wid,
# n_edits,
# colours,
# crec,
# target_lo,
# target_hi,
# depth,
# n_ent,
# n_rel,
# seed,
# )
# )

# # merge when workers finish
# for fut in as_completed(futures):
# added, removed_specs = fut.result()
# # --- apply additions immediately (cheap, conflict-free)
# for h, r, t in added:
# triples.append((h, r, t))
# out_adj[h].append((r, t))
# in_adj[t].append((r, h))
# # --- apply removals: interpret the spec on *current* graph
# for typ, randv in removed_specs:
# if typ == "div":
# idxs = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]]
# else: # 'det'
# sig_rel = defaultdict(set)
# for h, r, t in triples:
# sig_rel[(colours[h], colours[t])].add(r)
# idxs = [
# i
# for i, (h, r, t) in enumerate(triples)
# if len(sig_rel[(colours[h], colours[t])]) == 1
# ]
# if idxs:
# victim = idxs[int(randv * len(idxs))]
# h, r, t = triples.pop(victim)
# out_adj[h].remove((r, t))
# in_adj[t].remove((r, h))

# edit_budget_total += metric_every

# # =========================================================
# # 2. SINGLE-THREADED METRIC CHECKPOINT (synchronised)
# # =========================================================
# # if edit_budget_total % (10 * metric_every) == 0:
# if True:
# with metric_lock:
# crec, colours = _metric_worker((triples, n_ent, n_rel, depth))
# logging.info(
# "After %d edits WL-CREC = %.4f |T|=%d", edit_budget_total, crec, len(triples)
# )

# if target_lo <= crec <= target_hi:
# logging.info("Target band reached.")
# pool.shutdown(wait=True)
# return triples

# pool.shutdown(wait=True)
# raise RuntimeError("Exceeded max iteration budget without success.")


# from __future__ import annotations

# import logging
# import random
# from collections import defaultdict
# from concurrent.futures import ThreadPoolExecutor, as_completed
# from typing import List, Tuple

# import torch

# from data.kg_dataset import KGDataset
# from metrics.wlcrec import WLCREC
# from tools import get_pretty_logger

# logger = get_pretty_logger(__name__)

# # ------------ additions ------------
# def propose_add_div(triples: List, colours, n_ent: int, n_rel: int, rng: random.Random):
# for _ in range(1000):
# h, t = rng.randrange(n_ent), rng.randrange(n_ent)
# if colours[h] != colours[t]:
# r = rng.randrange(n_rel)
# return ("add", (h, r, t))
# return None # fell through – extremely rare


# def propose_add_det(colours, n_rel: int, rng: random.Random):
# σ, τ = rng.choice(colours), rng.choice(colours)
# h, t = colours.index(σ), colours.index(τ)
# return ("add", (h, 0, t)) # rel 0 = deterministic


# # ------------ removals ------------
# def propose_rem_div(triples: List, colours, rng: random.Random):
# cand = [trp for trp in triples if colours[trp[0]] != colours[trp[2]]]
# return ("rem", rng.choice(cand)) if cand else None


# def propose_rem_det(triples: List, colours, rng: random.Random):
# sig_rel = defaultdict(set)
# for h, r, t in triples:
# sig_rel[(colours[h], colours[t])].add(r)
# cand = [trp for trp in triples if len(sig_rel[(colours[trp[0]], colours[trp[2]])]) == 1]
# return ("rem", rng.choice(cand)) if cand else None


# def make_edit_proposal(
# triples_snapshot: List,
# colours_snapshot,
# c: float,
# lo: float,
# hi: float,
# n_ent: int,
# n_rel: int,
# depth: int, # still here for future use / signatures
# seed: int,
# ):
# """Return exactly one ('add' | 'rem', triple) proposal or None."""
# rng = random.Random(seed)

# # -- decide which kind of edit we want, *given* the current c ------------
# if c < lo: # ↑ WL‑CREC (delete deterministic ∨ add diversifying)
# chooser = (propose_rem_det, propose_add_div)
# else: # ↓ WL‑CREC (delete diversifying ∨ add deterministic)
# chooser = (propose_rem_div, propose_add_det)

# op = rng.choice(chooser)
# return (
# op(triples_snapshot, colours_snapshot, n_ent, n_rel, rng)
# if op.__name__.startswith("propose_add")
# else op(triples_snapshot, colours_snapshot, rng)
# )


# def tune_crec_parallel_edits(
# triples: List,
# n_ent: int,
# n_rel: int,
# depth: int,
# lo: float,
# hi: float,
# *,
# max_iters: int = 80_000,
# edits_per_eval: int = 1000, # == old “if it % 1000 == 0”
# batch_size: int = 256, # how many proposals we farm out at once
# max_workers: int = 4,
# seed: int = 42,
# ) -> List:
# rng_global = random.Random(seed)

# triples = set(triples) # deduplicate if needed

# with ThreadPoolExecutor(max_workers=max_workers) as pool:
# proposal_seed = seed * 997 # deterministic but different stream

# # edit_counter = 0
# # while edit_counter < max_iters:
# # # ----------------- expensive part (single‑thread) ----------------
# # dataset = KGDataset(
# # triples_factory=None, # Not used in this context
# # triples=torch.tensor(triples, dtype=torch.long),
# # num_entities=n_ent,
# # num_relations=n_rel,
# # )
# # tmp = WLCREC(dataset)
# # colours = tmp.wl_colours(depth)
# # *_, c, _ = tmp.compute(H=5)
# # # -----------------------------------------------------------------

# # if lo <= c <= hi:
# # logger.info(
# # "WL‑CREC %.4f reached after %d edits |T|=%d", c, edit_counter, len(triples)
# # )
# # return triples

# # ============ parallel block: just make `edits_per_eval` proposals
# needed = min(edits_per_eval, max_iters - edit_counter)
# proposals = []

# while len(proposals) < needed:
# # launch a batch of workers
# futs = [
# pool.submit(
# make_edit_proposal,
# triples,
# colours,
# c,
# lo,
# hi,
# n_ent,
# n_rel,
# depth,
# proposal_seed + i,
# )
# for i in range(batch_size)
# ]
# for f in as_completed(futs):
# prop = f.result()
# if prop is not None:
# proposals.append(prop)
# if len(proposals) == needed:
# break
# proposal_seed += batch_size # move RNG window forward

# # -------------- apply the gathered proposals *sequentially* -------
# for kind, trp in proposals:
# if kind == "add":
# triples.append(trp)
# else: # "rem"
# try:
# triples.remove(trp)
# except ValueError:
# pass # already gone – benign collision
# # -----------------------------------------------------------------
# edit_counter += needed
# if edit_counter % 1_000 == 0:
# logger.info("[iter %d] c=%.4f |T|=%d", edit_counter, c, len(triples))

# raise RuntimeError("Exceeded max_iters without hitting target band.")


import os
import random
import threading
from collections import defaultdict
from typing import Dict, List, Tuple

# --------------------------------------------------------------------
# Logging helper (replace with your own if you prefer) --------------
# --------------------------------------------------------------------
from tools import get_pretty_logger

# NOTE(review): the module logger is bound to the name ``logging``, which reads like
# the stdlib module; every ``logging.info`` / ``logging.warning`` call in this file
# goes through the object returned by ``get_pretty_logger``.
logging = get_pretty_logger(__name__)

# --------------------------------------------------------------------
# Original edge‑editing primitives (unchanged) ----------------------
# --------------------------------------------------------------------


def add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng):
    """
    Try to add one random edge whose endpoints have different colours.

    Up to 1000 random (head, tail) pairs are drawn from ``rng``. The first pair
    with ``colours[head] != colours[tail]`` gets a random relation and is
    appended to ``triples``, ``out_adj[head]`` and ``in_adj[tail]`` in place.
    If no such pair is found nothing is modified. ``depth`` is unused.
    Always returns ``None``.
    """
    attempts_left = 1000
    while attempts_left > 0:
        attempts_left -= 1
        head = rng.randrange(n_ent)
        tail = rng.randrange(n_ent)
        if colours[head] == colours[tail]:
            continue  # same colour class on both ends: try another pair
        rel = rng.randrange(n_rel)
        triples.append((head, rel, tail))
        out_adj[head].append((rel, tail))
        in_adj[tail].append((rel, head))
        break


def add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng):
    """
    Add one edge with relation 0 between representatives of two random colours.

    Two colours are sampled with ``rng.choice(colours)``; each is mapped to the
    first entity carrying that colour (``colours.index``). The edge
    ``(head, 0, tail)`` is appended to ``triples`` and mirrored into
    ``out_adj`` / ``in_adj`` in place. ``depth`` and ``n_rel`` are unused.
    """
    source_colour = rng.choice(colours)
    target_colour = rng.choice(colours)
    head = colours.index(source_colour)
    tail = colours.index(target_colour)
    relation = 0  # the relation id is hard-coded to 0 here
    triples.append((head, relation, tail))
    out_adj[head].append((relation, tail))
    in_adj[tail].append((relation, head))


def rem_div_edge(triples, out_adj, in_adj, colours, depth, rng):
    """
    Remove one random edge whose endpoints have different colours.

    The chosen triple is popped from ``triples`` and its entries are removed
    from ``out_adj`` / ``in_adj`` in place. Nothing happens (and ``rng`` is not
    consulted) when no such edge exists. ``depth`` is unused.
    """
    candidates = []
    for position, (head, _rel, tail) in enumerate(triples):
        if colours[head] != colours[tail]:
            candidates.append(position)
    if not candidates:
        return
    victim = rng.choice(candidates)
    head, rel, tail = triples.pop(victim)
    out_adj[head].remove((rel, tail))
    in_adj[tail].remove((rel, head))


def rem_det_edge(triples, out_adj, in_adj, colours, depth, rng):
    """
    Remove one random "deterministic" edge.

    An edge counts as deterministic when its (head colour, tail colour) pair is
    connected by exactly one distinct relation across all of ``triples``. The
    chosen triple is popped from ``triples`` and removed from ``out_adj`` /
    ``in_adj`` in place. Nothing happens when no such edge exists. ``depth`` is
    unused.
    """
    relations_by_signature = defaultdict(set)
    for head, rel, tail in triples:
        relations_by_signature[(colours[head], colours[tail])].add(rel)

    candidates = [
        position
        for position, (head, _rel, tail) in enumerate(triples)
        if len(relations_by_signature[(colours[head], colours[tail])]) == 1
    ]
    if not candidates:
        return

    victim = rng.choice(candidates)
    head, rel, tail = triples.pop(victim)
    out_adj[head].remove((rel, tail))
    in_adj[tail].remove((rel, head))


def _worker(
    *,
    worker_id: int,
    rng: random.Random,
    triples: List[Tuple[int, int, int]],
    out_adj: Dict[int, List[Tuple[int, int]]],
    in_adj: Dict[int, List[Tuple[int, int]]],
    colours: List[int],
    n_ent: int,
    n_rel: int,
    depth: int,
    c: float,
    lo: float,
    hi: float,
    max_iters: int,
    state_lock: threading.Lock,
    stop_event: threading.Event,
):
    """
    Thread body: apply random edits to the *shared* graph structures.

    Each iteration takes ``state_lock`` and applies exactly one edit. When
    ``c < lo`` it either removes a deterministic edge or adds a diversifying
    one; otherwise it either removes a diversifying edge or adds a
    deterministic one (a fair coin from ``rng`` picks between the two).
    ``c`` is never refreshed inside this function, so the same side of the
    branch is used for the whole run; ``hi`` is not read.

    The loop ends early once ``stop_event`` is set. A worker that exhausts
    ``max_iters`` logs a warning and sets ``stop_event`` itself.
    """
    for _step in range(max_iters):
        if stop_event.is_set():
            return  # another worker already signalled stop

        with state_lock:  # one edit at a time on the shared graph
            # if lo <= c <= hi:
            #     logging.info(
            #         "[worker %d] converged after %d steps (CREC %.4f, |T|=%d)",
            #         worker_id,
            #         it,
            #         c,
            #         len(triples),
            #     )
            #     stop_event.set()
            #     return

            needs_increase = c < lo
            coin_is_heads = rng.random() < 0.5
            if needs_increase:
                if coin_is_heads:
                    rem_det_edge(triples, out_adj, in_adj, colours, depth, rng)
                else:
                    add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng)
            elif coin_is_heads:
                rem_div_edge(triples, out_adj, in_adj, colours, depth, rng)
            else:
                add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng)

    logging.warning("[worker %d] reached max_iters", worker_id)
    stop_event.set()
    return


# --------------------------------------------------------------------
# Public API --------------------------------------------------------
# --------------------------------------------------------------------


def fast_tune_crec(
    triples: List,
    colours: List,
    n_ent: int,
    n_rel: int,
    depth: int,
    c: float,
    lo: float,
    hi: float,
    max_iters: int = 1000,
    max_workers: int | None = None,
    seeds: List[int] | None = None,
) -> List[Tuple[int, int, int]]:
    """
    Edit the *shared* ``triples`` list from several threads at once.

    Adjacency maps are built once from ``triples`` and handed, together with
    the list itself, to ``max_workers`` threads running ``_worker``. All
    threads mutate the same structures under one lock.

    Parameters
    ----------
    triples, colours :
        Shared graph and per-entity colours, passed through to the workers.
    n_ent, n_rel, depth, c, lo, hi, max_iters :
        Forwarded unchanged to every worker.
    max_workers :
        Number of threads; ``None`` means 4.
    seeds :
        One RNG seed per worker; defaults to ``42, 43, ...``.

    Returns
    -------
    The **same list instance** that was passed in, modified in place.

    Raises
    ------
    RuntimeError
        If the stop event was never set once all threads have finished.
    """
    # max_workers = os.cpu_count() or 4
    max_workers = 4 if max_workers is None else max_workers
    if seeds is None:
        seeds = list(range(42, 42 + max_workers))
    assert len(seeds) >= max_workers, "Need at least one seed per worker"

    # Adjacency views shared by every worker.
    out_adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
    in_adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
    for head, rel, tail in triples:
        out_adj[head].append((rel, tail))
        in_adj[tail].append((rel, head))

    state_lock = threading.Lock()
    stop_event = threading.Event()

    logging.info(
        "Launching %d threads on shared triples (target %.3f–%.3f)",
        max_workers,
        lo,
        hi,
    )

    # Arguments that are identical for every worker.
    shared_kwargs = dict(
        triples=triples,
        out_adj=out_adj,
        in_adj=in_adj,
        colours=colours,
        n_ent=n_ent,
        n_rel=n_rel,
        depth=depth,
        c=c,
        lo=lo,
        hi=hi,
        max_iters=max_iters,
        state_lock=state_lock,
        stop_event=stop_event,
    )
    workers: List[threading.Thread] = [
        threading.Thread(
            name=f"tune‑crec‑worker‑{wid}",
            target=_worker,
            kwargs=dict(worker_id=wid, rng=random.Random(seeds[wid]), **shared_kwargs),
            daemon=False,
        )
        for wid in range(max_workers)
    ]
    for thread in workers:
        thread.start()
    for thread in workers:
        thread.join()

    if not stop_event.is_set():
        raise RuntimeError("No thread converged – try increasing max_iters or widen band")

    return triples

+ 106
- 0
metrics/crec_radius_sample.py View File

@@ -0,0 +1,106 @@
from __future__ import annotations

import logging
import random
from collections import defaultdict, deque
from math import isclose
from typing import List

import torch

from data.kg_dataset import KGDataset
from metrics.wlcrec import WLCREC
from tools import get_pretty_logger

# NOTE(review): this rebinding shadows the stdlib ``logging`` module imported above;
# from here on ``logging`` is the logger returned by ``get_pretty_logger``.
logging = get_pretty_logger(__name__)



# --------------------------------------------------------------------
# Radius-k patch sampler
# --------------------------------------------------------------------
def radius_patch_sample(
    all_triples: List,
    ratio: float,
    radius: int,
    rng: random.Random,
) -> List:
    """
    Sample roughly ``ratio`` of the triples as connected "patches".

    Random seed triples are drawn until at least ``int(ratio * len(all_triples))``
    triples are selected. For each seed, a breadth-first search over entities
    (edges treated as undirected) keeps *all* triples incident to every entity
    reached within ``radius`` hops of the seed's endpoints. Because whole
    patches are added, the result may overshoot the target.

    Parameters
    ----------
    all_triples :
        Sequence of ``(head, relation, tail)`` triples.
    ratio :
        Requested fraction of triples. The target is capped at
        ``len(all_triples)``, so values above 1 return every triple; a ratio
        small enough that the target rounds down to 0 returns ``[]``.
    radius :
        Maximum BFS distance (in entity hops) expanded from each seed endpoint.
    rng :
        Source of randomness for seed selection.

    Returns
    -------
    List of the selected triples (same objects as in ``all_triples``).
    """
    n_total = len(all_triples)
    # Cap the target: asking for more triples than exist would otherwise keep
    # the rejection loop below spinning forever once everything is chosen.
    target = min(int(ratio * n_total), n_total)

    # Entity -> indices of the triples it appears in (as head or tail).
    ent2trip = defaultdict(list)
    for idx, (h, _, t) in enumerate(all_triples):
        ent2trip[h].append(idx)
        ent2trip[t].append(idx)

    chosen_idx: set[int] = set()
    visited_ent = set()

    while len(chosen_idx) < target:
        seed_idx = rng.randrange(n_total)
        if seed_idx in chosen_idx:
            continue  # already covered by an earlier patch; draw again

        # BFS over entities, starting from the seed's head and tail.
        queue = deque()
        for e in all_triples[seed_idx][::2]:
            if e not in visited_ent:
                visited_ent.add(e)
                queue.append((e, 0))

        while queue:
            ent, dist = queue.popleft()
            incident = ent2trip[ent]
            chosen_idx.update(incident)
            if dist >= radius:
                continue  # frontier reached: keep its triples, do not expand
            for tidx in incident:
                h, _, t = all_triples[tidx]
                for nb in (h, t):
                    if nb not in visited_ent:
                        visited_ent.add(nb)
                        queue.append((nb, dist + 1))

    return [all_triples[i] for i in chosen_idx]


# --------------------------------------------------------------------
# Main driver: repeated sampling until WL-CREC band hit
# --------------------------------------------------------------------
def find_sample(
    all_triples: List,
    n_ent: int,
    n_rel: int,
    depth: int,
    ratio: float,
    radius: int,
    lower: float,
    upper: float,
    max_tries: int,
    seed: int,
) -> List:
    """
    Draw radius-patch samples until one lands in the WL-CREC band.

    Trial ``k`` (1-based) samples with ``random.Random(seed + k)``, wraps the
    sample in a ``KGDataset`` and reads the second-to-last element of
    ``WLCREC(...).compute(depth)`` as the WL-CREC value.

    Returns
    -------
    The first sample whose value satisfies ``lower <= value <= upper``.

    Raises
    ------
    RuntimeError
        If none of the ``max_tries`` samples falls inside the band.
    """
    for trial in range(1, max_tries + 1):
        trial_rng = random.Random(seed + trial)  # fresh randomness each try
        candidate = radius_patch_sample(all_triples, ratio, radius, trial_rng)

        candidate_tensor = torch.tensor(candidate, dtype=torch.long)
        candidate_ds = KGDataset(
            triples_factory=None,
            triples=candidate_tensor,
            num_entities=n_ent,
            num_relations=n_rel,
            split="train",
        )
        metric_outputs = WLCREC(candidate_ds).compute(depth)
        crec_val = metric_outputs[-2]

        logging.info("Try %d |T'|=%d WL-CREC=%.4f", trial, len(candidate), crec_val)
        if not (lower <= crec_val <= upper):
            continue
        logging.info("Success after %d tries.", trial)
        return candidate

    raise RuntimeError(f"No sample reached WL-CREC band in {max_tries} tries.")

+ 226
- 0
metrics/greedy_crec.py View File

@@ -0,0 +1,226 @@
from __future__ import annotations

import logging
import random
from math import log
from typing import Dict, List, Tuple

import torch
from torch import Tensor

from tools import get_pretty_logger

# NOTE(review): this rebinding shadows the stdlib ``logging`` module imported above;
# from here on ``logging`` is the logger returned by ``get_pretty_logger``.
logging = get_pretty_logger(__name__)


# Type aliases used in the signatures below.
Entity = int  # integer entity id
Relation = int  # integer relation id
Triple = Tuple[Entity, Relation, Entity]  # unpacked as (h, r, t) throughout this file
Color = int  # presumably a WL colour id — not referenced in the visible code


# ---------------------------------------------------------------------
# Weisfeiler–Lehman colouring (GPU friendly)
# ---------------------------------------------------------------------
def wl_colours_gpu(triples: Tensor, n_ent: int, depth: int, device: torch.device) -> List[Tensor]:
    """
    Hash-based Weisfeiler-Lehman colour refinement on torch tensors.

    1. Every node starts with colour 0.
    2. Each round hashes, per node, the multiset of outgoing
       ``(relation, colour(tail))`` and incoming ``(relation, colour(head))``
       signatures together with the node's current colour.
    3. The raw 64-bit hashes are re-mapped to dense consecutive ids with
       ``torch.unique``. Hash collisions are possible in principle; two nodes
       that collide simply share a colour.

    Parameters
    ----------
    triples :
        Long tensor of shape (|T|, 3) holding (head, relation, tail) ids;
        assumed to live on ``device``. An empty tensor is accepted.
    n_ent :
        Number of entities (length of every colour vector).
    depth :
        Number of refinement rounds.
    device :
        Device on which the colour tensors are allocated.

    Returns
    -------
    List of ``depth + 1`` long tensors of shape (n_ent,), one per round,
    starting with the all-zero initial colouring.
    """
    if triples.numel() == 0:
        triples = triples.reshape(0, 3)  # lets unbind work for a bare empty tensor
    h, r, t = triples.unbind(dim=1)  # each (|T|,)
    colours = [torch.zeros(n_ent, dtype=torch.long, device=device)]  # round-0 colours

    # Per-relation mixing constants. The "+ 1" keeps relation 0 from hashing to
    # 0, which would make an edge (r=0, colour 0) look identical to "no edge".
    n_rel_seen = int(r.max().item()) + 1 if r.numel() else 0
    mix_r = (
        torch.arange(n_rel_seen, device=device, dtype=torch.long) + 1
    ) * 1_146_189_683_093_321_123
    mix_c = 8_636_673_225_737_527_201
    # 0x9E3779B97F4A7C15 does not fit in a signed 64-bit integer, so XOR-ing it
    # into a long tensor fails; use the same bit pattern in two's complement.
    in_tag = 0x9E3779B97F4A7C15 - (1 << 64)

    for _ in range(depth):
        prev = colours[-1]  # (|V|,)

        # Per-edge signatures; the tag separates incoming from outgoing edges.
        sig_out = mix_r[r] ^ (prev[t] * mix_c)
        sig_in = mix_r[r] ^ (prev[h] * mix_c) ^ in_tag

        # Sum the signatures per node (order-independent multiset hash;
        # int64 overflow simply wraps around).
        col_out = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, h, sig_out)
        col_in = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, t, sig_in)

        # Combine both directions with the node's current colour.
        raw = (prev * 3_205_813_371) ^ col_out ^ (col_in << 1)

        # Dense consecutive colour ids.
        _, new = torch.unique(raw, sorted=True, return_inverse=True)
        colours.append(new)

    return colours  # length = depth + 1, each (|V|,) long


# ---------------------------------------------------------------------
# WL‑CREC metric
# ---------------------------------------------------------------------
def wl_crec_gpu(
    triples: Tensor, colours: List[Tensor], n_ent: int, n_rel: int, depth: int, device: torch.device
) -> float:
    """
    WL-CREC score: pattern diversity ``C`` times residual relation entropy ``H_c``.

    * ``C`` is the Shannon entropy of the colour histogram, summed over every
      colouring in ``colours`` and divided by ``(depth + 1) * log(max(n_ent, 2))``.
    * ``H_c`` is the conditional entropy of the relation given the
      (head colour, tail colour) signature at round ``depth``, weighted by
      signature frequency and divided by ``log(max(n_rel, 2))``.

    Parameters
    ----------
    triples :
        Long tensor of shape (|T|, 3) holding (head, relation, tail) ids.
    colours :
        Per-round colour vectors, e.g. the output of ``wl_colours_gpu``;
        ``colours[depth]`` is used for the signatures.
    n_ent, n_rel :
        Number of entities / relations; relation ids must be ``< n_rel``.
    depth :
        WL depth the colours were computed for.
    device :
        Unused; all work happens on the devices of the input tensors. Kept so
        existing callers do not need to change.

    Returns
    -------
    ``C * H_c`` as a Python float; ``0.0`` when there are no triples.
    """
    log_n = log(max(n_ent, 2))

    # ------- pattern-diversity C -------------------------------------
    C = 0.0
    for c in colours:  # (|V|,)
        hist = torch.bincount(c).float()
        p = hist / n_ent
        # clamp keeps log finite for empty histogram bins (they contribute 0)
        C += -(p * torch.log(p.clamp_min(1e-30))).sum().item()
    C /= (depth + 1) * log_n

    # ------- residual entropy H_c ------------------------------------
    m = triples.size(0)
    if m == 0:
        return 0.0  # no edges -> no residual relation entropy

    h, r, t = triples.unbind(dim=1)
    final = colours[depth]
    sigma, tau = final[h], final[t]  # (|T|,)

    # Encode each (sigma, tau) pair into a single 64-bit key.
    sig_keys = (sigma.to(torch.int64) << 32) + tau.to(torch.int64)
    _, sig_inv, sig_counts = torch.unique(sig_keys, return_inverse=True, return_counts=True)

    # Count only the (signature, relation) pairs that actually occur rather
    # than materialising a dense |signatures| x n_rel table.
    pair_keys = sig_inv * n_rel + r
    uniq_pairs, pair_counts = torch.unique(pair_keys, return_counts=True)
    pair_sig = torch.div(uniq_pairs, n_rel, rounding_mode="floor")

    n_sr = pair_counts.float()
    p_r = n_sr / sig_counts[pair_sig].float()  # P(r | signature)
    # sum_s (n_s / m) * H(r | s)  ==  -sum_{s,r} (n_sr / m) * log P(r | s)
    Hc = -(n_sr / m * torch.log(p_r)).sum().item() / log(max(n_rel, 2))

    return C * Hc


# ---------------------------------------------------------------------
# delta‑entropy helpers (still CPU side for clarity; cost is negligible)
# ---------------------------------------------------------------------
# The deterministic delta‑formulas depend on small dictionaries and are evaluated
# on ≤256 candidate edges each iteration – copy them from the original script.
# NOTE(review): these three names are bound to the ``Ellipsis`` placeholder, not to
# functions, so calling any of them raises ``TypeError``. The real implementations
# still have to be copied in from the original script before they can be used.
inner_entropy = ... # unchanged
delta_h_cond_remove = ... # unchanged
delta_h_cond_add = ... # unchanged


# ---------------------------------------------------------------------
# Greedy delta‑search driver
# ---------------------------------------------------------------------
def greedy_delta_tune_gpu(
    triples: List[Triple],
    n_ent: int,
    n_rel: int,
    depth: int,
    lower: float,
    upper: float,
    device: torch.device,
    max_iters: int = 40_000,
    sample_size: int = 256,
    seed: int = 0,
) -> List[Triple]:
    """Greedily add/remove triples until WL-CREC falls inside ``[lower, upper]``.

    Each iteration recomputes the WL colours and the WL-CREC score, then applies
    exactly one edit: removing an existing triple when the score is below
    ``lower``, or appending a copy of an existing triple otherwise.

    Parameters
    ----------
    triples
        Working list of ``(h, r, t)`` triples. **Mutated in place** and also returned.
    n_ent, n_rel
        Number of entities / relations, forwarded to the metric helpers.
    depth
        WL refinement depth; ``colours[depth]`` supplies the edge signatures.
    lower, upper
        Inclusive target band for the WL-CREC score.
    device
        Torch device used for the metric evaluation.
    max_iters
        Maximum number of single-triple edits.
    sample_size
        Maximum number of candidate edges scored per iteration.
    seed
        Seed of the private ``random.Random`` used for sampling.

    Returns
    -------
    The (mutated) ``triples`` list once the score is inside the band.

    Raises
    ------
    RuntimeError
        If no edit can be proposed for the current graph, or if ``max_iters``
        edits were applied without reaching the band.
    """
    # book-keeping in Python lists (cheap); heavy math in torch on GPU
    rng = random.Random(seed)

    def triples_to_tensor(ts: List[Triple]) -> Tensor:
        return torch.tensor(ts, dtype=torch.long, device=device, requires_grad=False)

    for it in range(max_iters):
        tri_tensor = triples_to_tensor(triples)
        colours = wl_colours_gpu(tri_tensor, n_ent, depth, device)
        crec = wl_crec_gpu(tri_tensor, colours, n_ent, n_rel, depth, device)

        if lower <= crec <= upper:
            logging.info("WL‑CREC %.4f reached after %d edits (|T|=%d)", crec, it, len(triples))
            return triples

        # ----------------------------------------------------------------
        # build signature statistics (CPU, cheap: O(|T|))
        # ----------------------------------------------------------------
        # One bulk transfer to plain ints instead of a tensor lookup per triple.
        depth_col = colours[depth].cpu().tolist()
        sigs: List[Tuple[int, int]] = [(depth_col[h], depth_col[t]) for h, _, t in triples]

        sig_cnt: Dict[Tuple[int, int], int] = {}
        rel_cnt: Dict[Tuple[int, int, int], int] = {}
        for (sigma, tau), (_, r, _) in zip(sigs, triples):
            sig_cnt[(sigma, tau)] = sig_cnt.get((sigma, tau), 0) + 1
            rel_cnt[(sigma, tau, r)] = rel_cnt.get((sigma, tau, r), 0) + 1

        # Number of distinct relations (restricted to ids in [0, n_rel)) per signature.
        # Every rel_cnt value is >= 1, so counting keys equals counting non-zero entries
        # without scanning all n_rel relations for every triple.
        rels_per_sig: Dict[Tuple[int, int], int] = {}
        for sigma, tau, r in rel_cnt:
            if 0 <= r < n_rel:
                rels_per_sig[(sigma, tau)] = rels_per_sig.get((sigma, tau), 0) + 1

        # det_edges: the signature determines the relation uniquely.
        det_edges = [idx for idx, sig in enumerate(sigs) if rels_per_sig.get(sig, 0) == 1]
        # div_edges: head and tail carry different colours.
        div_edges = [idx for idx, (sigma, tau) in enumerate(sigs) if sigma != tau]

        # ----------------------------------------------------------------
        # candidate generation + best edit selection
        # ----------------------------------------------------------------
        total = len(triples)
        target_high = crec < lower  # need to raise WL‑CREC?
        candidates = []

        if det_edges:
            rng.shuffle(det_edges)
            for idx in det_edges[:sample_size]:
                h, r, t = triples[idx]
                sig = sigs[idx]
                if target_high:
                    delta = delta_h_cond_remove(sig, r, sig_cnt, rel_cnt, total, crec, n_rel)
                    if delta > 0:
                        candidates.append(("remove", idx, delta))
                else:
                    delta = delta_h_cond_add(sig, r, sig_cnt, rel_cnt, total, crec, n_rel)
                    if delta < 0:
                        candidates.append(("add", (h, r, t), delta))

        # fall-back heuristics
        if not candidates:
            if target_high:
                if not div_edges:
                    # Previously this fell through to max([]) and an opaque ValueError.
                    msg = "No candidate edit can raise WL-CREC: no removable edge was found."
                    raise RuntimeError(msg)
                candidates.append(("remove", rng.choice(div_edges), 1e-9))
            else:
                if det_edges:
                    idx = rng.choice(det_edges)
                elif total:
                    idx = rng.randrange(total)
                else:
                    msg = "No candidate edit can lower WL-CREC: the triple list is empty."
                    raise RuntimeError(msg)
                h, r, t = triples[idx]
                candidates.append(("add", (h, r, t), -1e-9))

        best = (
            max(candidates, key=lambda x: x[2])
            if target_high
            else min(candidates, key=lambda x: x[2])
        )

        # apply edit
        if best[0] == "remove":
            triples.pop(best[1])
        else:
            triples.append(best[1])

        if (it + 1) % 1_000 == 0:
            logging.info("[iter %d] WL‑CREC %.4f |T|=%d", it + 1, crec, len(triples))

    raise RuntimeError("Max iterations exceeded without hitting WL‑CREC band.")

+ 30
- 0
metrics/ranking.py View File

@@ -0,0 +1,30 @@
import torch

from models.base_model import ERModel


def _rank(scores: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return one plus the number of entries in *scores* strictly above ``scores[target]``."""
    target_score = scores[target]
    strictly_better = torch.gt(scores, target_score)
    return strictly_better.sum() + 1


def simple_ranking(model: ERModel, triples: torch.Tensor) -> torch.Tensor:
    """Rank the true tail of every triple against all entities as candidate tails.

    For each ``(h, r, t)`` row, the model scores ``(h, r, e)`` for every entity id
    ``e`` in ``range(model.num_entities)`` and the rank of ``t`` is taken via
    ``_rank``. No filtering of other known-true tails is applied.

    Parameters
    ----------
    model
        Scoring model; called as ``model(candidates, mode=None)``.
    triples
        Integer tensor whose columns 0, 1, 2 are head, relation and tail ids.

    Returns
    -------
    A float32 tensor with one rank per input triple, on ``triples.device``.
    """
    heads, rels, true_tails = triples[:, 0], triples[:, 1], triples[:, 2]
    E = model.num_entities
    device = triples.device
    # Candidate tails are identical for every query, so build them once.
    tails = torch.arange(E, device=device)
    ranks = []

    with torch.no_grad():
        for hi, ri, ti in zip(heads, rels, true_tails):
            # Separate name: the original loop overwrote the `triples` argument here.
            candidates = torch.stack([hi.repeat(E), ri.repeat(E), tails], dim=1)
            # `ERModel.forward` declares `mode` as a required keyword-only argument.
            scores = model(candidates, mode=None)
            ranks.append(_rank(scores, ti))

    if not ranks:
        return torch.empty(0, device=device, dtype=torch.float32)
    return torch.stack(ranks).to(dtype=torch.float32)


def mrr(r: torch.Tensor) -> torch.Tensor:
    """Return the mean of the element-wise reciprocals of the ranks *r*."""
    reciprocal_ranks = 1 / r
    return reciprocal_ranks.mean()


def hits_at_k(r: torch.Tensor, k: int = 10) -> torch.Tensor:
    """Return the fraction of ranks in *r* that are less than or equal to *k*."""
    within_k = torch.le(r, k)
    return within_k.float().mean()

+ 237
- 0
metrics/wlcrec.py View File

@@ -0,0 +1,237 @@
from __future__ import annotations

import math
from collections import Counter, defaultdict
from typing import TYPE_CHECKING, Dict, List, Tuple

import torch

from .base_metric import BaseMetric

if TYPE_CHECKING:
from data.kg_dataset import KGDataset


def _entropy_from_counter(counter: Counter[int], n: int) -> float:
    """Shannon entropy in nats of the colour counts in *counter*, normalised by *n*."""
    entropy = 0.0
    for occurrences in counter.values():
        if not occurrences:
            # zero counts contribute nothing and would hit log(0)
            continue
        prob = occurrences / n
        entropy -= prob * math.log(prob)
    return entropy


def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]:
    """Average the per-layer entropies into the pair ``(C_ratio, C_NWLEC)``.

    ``C_ratio`` averages ``ent / log(n)``; ``C_NWLEC`` averages
    ``(exp(ent) - 1) / (n - 1)``. Both are ``0.0`` when ``n <= 1``.
    """
    if n <= 1:
        return 0.0, 0.0

    layers = len(entropies)
    log_n = math.log(n)
    ratio_total = 0
    nwlec_total = 0
    for ent in entropies:
        ratio_total += ent / log_n
        nwlec_total += (math.exp(ent) - 1) / (n - 1)
    return ratio_total / layers, nwlec_total / layers


def _relation_families(
    triples: torch.Tensor, num_relations: int, *, thresh: float = 0.9
) -> torch.Tensor:
    """Return a LongTensor mapping each relation id → family id.

    Two relations *r, r_inv* are put into the same family when **at least
    `thresh` fraction** of *all* edges labelled *r* have a reverse edge
    labelled *r_inv*.

    Parameters
    ----------
    triples
        Tensor of ``(h, r, t)`` rows (anything whose ``.tolist()`` yields such rows).
    num_relations
        Number of relation ids; every ``r`` in *triples* must be below it.
    thresh
        Required fraction in ``(0, 1]``.

    Returns
    -------
    LongTensor of length ``num_relations``; relations sharing a value are one family.

    Raises
    ------
    ValueError
        If *thresh* is outside ``(0, 1]``.
    """
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not 0.0 < thresh <= 1.0:
        msg = f"thresh must lie in (0, 1], got {thresh!r}"
        raise ValueError(msg)

    # Build map (h,t) -> list of relations on that edge direction, and count
    # every edge of each relation. The denominator must cover *all* edges of r,
    # not only the reciprocated ones; otherwise a single reciprocated edge
    # yields a ratio of 1.0 regardless of how many edges r has.
    edge2rels: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    rel_totals: List[int] = [0] * num_relations
    for h, r, t in triples.tolist():
        r = int(r)
        edge2rels[(h, t)].append(r)
        rel_totals[r] += 1

    # Count reciprocal co-occurrences (r, r_rev): number of r-edges whose reverse
    # direction carries r_rev at least once (duplicate reverse labels count once).
    pair_counts: Dict[Tuple[int, int], int] = defaultdict(int)
    for (h, t), rels in edge2rels.items():
        rev_rels = edge2rels.get((t, h))
        if not rev_rels:
            continue
        distinct_rev = list(dict.fromkeys(rev_rels))  # de-duplicate, keep order
        for r in rels:
            for r_rev in distinct_rev:
                pair_counts[(r, r_rev)] += 1

    # Proposed inverse mapping r -> r_inv. If several candidates pass the
    # threshold, the one encountered last wins.
    proposed: List[int | None] = [None] * num_relations
    for (r, r_rev), cnt in pair_counts.items():
        if cnt / rel_totals[r] >= thresh:
            proposed[r] = r_rev

    # Build family ids (transitive closure is overkill; data are simple).
    family = list(range(num_relations))
    for r, rinv in enumerate(proposed):
        if rinv is not None:
            root = min(family[r], family[rinv])
            family[r] = family[rinv] = root
    return torch.tensor(family, dtype=torch.long)


def _conditional_relation_entropy(
    colours: List[int],
    triples: torch.Tensor,
    family_map: torch.Tensor,
) -> float:
    """Entropy (nats) of the relation family given the unordered endpoint-colour pair.

    Parameters
    ----------
    colours
        Colour of every entity, indexed by entity id.
    triples
        Tensor of ``(h, r, t)`` rows.
    family_map
        Maps a relation id to its family id (indexable by ``r``).

    Returns
    -------
    ``Σ_s P(s) · H(family | s)`` where ``s`` ranges over the sorted colour pairs
    of the triples; ``0.0`` when there are no triples.
    """
    counts: Dict[Tuple[int, int], Dict[int, int]] = defaultdict(lambda: defaultdict(int))
    m = int(triples.size(0))
    # Convert once to plain ints: indexing the tensor per triple is needlessly slow.
    if isinstance(family_map, torch.Tensor):
        fam = family_map.tolist()
    else:
        fam = [int(f) for f in family_map]

    for h, r, t in triples.tolist():
        # unordered key (direction-agnostic)
        key = tuple(sorted((colours[h], colours[t])))
        counts[key][fam[r]] += 1

    h_cond = 0.0
    for rel_counts in counts.values():
        total_s = sum(rel_counts.values())
        P_s = total_s / m
        inv_total = 1.0 / total_s
        for cnt in rel_counts.values():
            p_rs = cnt * inv_total
            h_cond -= P_s * p_rs * math.log(p_rs)
    return h_cond


class WLCREC(BaseMetric):
    """Weisfeiler-Lehman colour-entropy scores combined with conditional relation entropy.

    ``compute`` refines entity colours for ``H`` rounds, records the colour entropy
    of every round, and multiplies the normalised diversity by the residual
    relation entropy measured at layer ``cond_h``.
    """

    def __init__(self, dataset: KGDataset) -> None:
        super().__init__(dataset)

    def compute(
        self,
        H: int = 3,
        cond_h: int = 1,
        inv_thresh: float = 0.9,
        return_full: bool = False,
        progress: bool = True,
    ) -> (
        Tuple[float, float, float, float, float, float]
        | Tuple[float, float, float, float, float, float, List[float]]
    ):
        """Compute WL-entropy scores and the composite difficulty metric.

        Parameters
        ----------
        H
            WL refinement depth (*≥0*). Entropy is recorded for **H+1** layers.
        cond_h
            Index of the stored colouring used for the conditional relation
            entropy; must address one of the ``H+1`` layers.
        inv_thresh
            Threshold forwarded to ``_relation_families`` as ``thresh``.
        return_full
            Also return the list of per-layer entropies.
        progress
            Print textual progress bar.

        Returns
        -------
        avg_entropy, C_ratio, C_NWLEC, H_cond, D_ratio, D_NWLEC
            *Six* scalars (floats). If ``return_full=True`` an extra list of
            layer entropies is appended.

        Raises
        ------
        ValueError
            If ``H`` is negative.
        IndexError
            If ``cond_h`` does not address one of the ``H+1`` stored layers.
        """
        # Fail fast: both conditions used to surface only after the costly WL loop
        # (ZeroDivisionError for H < 0, IndexError for a bad cond_h).
        if H < 0:
            msg = f"H must be non-negative, got {H}"
            raise ValueError(msg)
        if not -(H + 1) <= cond_h <= H:
            msg = f"cond_h={cond_h} is out of range for {H + 1} stored WL layers"
            raise IndexError(msg)

        # compute relation family map once
        family_map = _relation_families(
            self.dataset.triples, self.dataset.num_relations, thresh=inv_thresh
        )

        # ─── WL iterations ────────────────────────────────────────────────────

        n = self.dataset.num_entities
        colour: List[int] = [1] * n  # colour 1 for everyone
        entropies: List[float] = []
        colours_per_h: List[List[int]] = []

        def _print_prog(i: int) -> None:
            if progress:
                print(f"\rWL iteration {i}/{H}", end="", flush=True)

        for h in range(H + 1):
            _print_prog(h)

            colours_per_h.append(colour.copy())

            # entropy of current colouring
            freq = Counter(colour)
            entropies.append(_entropy_from_counter(freq, n))

            if h == H:
                break

            # refine colours
            bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {}
            next_colour: List[int] = [0] * n

            # NOTE(review): `out_adj` / `in_adj` are not set in this class;
            # presumably provided by BaseMetric — confirm.
            for v in range(n):
                T: List[Tuple[int, int, int]] = []
                for r, t in self.out_adj[v]:
                    T.append((r, 0, colour[t]))  # outgoing
                for r, h_ in self.in_adj[v]:
                    T.append((r, 1, colour[h_]))  # incoming
                T.sort()
                key = (colour[v], tuple(T))
                if key not in bucket:
                    bucket[key] = len(bucket) + 1
                next_colour[v] = bucket[key]

            colour = next_colour

        _print_prog(H)
        if progress:
            print()

        # ─── Normalised diversity ────────────────────────────────────────────
        avg_entropy = sum(entropies) / len(entropies)
        c_ratio, c_nwlec = _normalise_scores(entropies, n)

        # ─── Residual relation entropy ───────────────────────────────────────
        h_cond = _conditional_relation_entropy(
            colours_per_h[cond_h], self.dataset.triples, family_map
        )

        # ─── Composite difficulty - Eq. (1) ──────────────────────────────────
        d_ratio = c_ratio * h_cond
        d_nwlec = c_nwlec * h_cond

        result = (avg_entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec)
        if return_full:
            result += (entropies,)
        return result

    # --------------------------------------------------------------------
    # Weisfeiler–Lehman colours (unchanged)
    # --------------------------------------------------------------------
    def wl_colours(self, depth: int) -> List[List[int]]:
        """Return the WL colourings for rounds ``0..depth`` (round 0 is all zeros).

        Colour ids are 0-based and assigned in order of first appearance per round.
        Adjacency is rebuilt here from ``self.dataset.triples``.
        """
        triples = self.dataset.triples.tolist()
        out_adj, in_adj = defaultdict(list), defaultdict(list)
        for h, r, t in triples:
            out_adj[h].append((r, t))
            in_adj[t].append((r, h))

        n_ent = self.dataset.num_entities
        colours = [[0] * n_ent]  # round-0
        for _ in range(depth):
            prev = colours[-1]
            nxt = [0] * n_ent
            sig2c: Dict[tuple, int] = {}
            for v in range(n_ent):
                neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [
                    ("↑", r, prev[u]) for r, u in in_adj.get(v, [])
                ]
                neigh.sort()
                sig = (prev[v], tuple(neigh))
                # a new signature gets the next unused id (= current table size)
                nxt[v] = sig2c.setdefault(sig, len(sig2c))
            colours.append(nxt)
        return colours

+ 135
- 0
metrics/wlec.py View File

@@ -0,0 +1,135 @@
from __future__ import annotations

import math
from collections import Counter
from typing import TYPE_CHECKING, Dict, List, Tuple

from .base_metric import BaseMetric

if TYPE_CHECKING:
from data.kg_dataset import KGDataset


def _entropy_from_counter(counter: Counter[int], n: int) -> float:
    """Return ``-Σ pᵢ log pᵢ`` (nats) where ``pᵢ = countᵢ / n`` for each colour count."""
    acc = 0.0
    for colour_count in counter.values():
        if colour_count == 0:
            continue  # skip empty classes: log(0) is undefined
        share = colour_count / n
        acc -= share * math.log(share)
    return acc


def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]:
    """Turn *per-iteration* entropies into the averaged pair ``(C_ratio, C_NWLEC)``."""
    # Degenerate graph.
    if n <= 1:
        return 0.0, 0.0

    depth_count = len(entropies)

    # Eq. (1): simple ratio normalisation.
    ratio_terms = [ent / math.log(n) for ent in entropies]
    c_ratio = sum(ratio_terms) / depth_count

    # Eq. (2): effective-colour NWLEC.
    c_nwlec = sum((math.exp(ent) - 1) / (n - 1) for ent in entropies) / depth_count
    return c_ratio, c_nwlec


class WLEC(BaseMetric):
    """Class to compute the Weisfeiler-Lehman Entropy Complexity (WLEC) for a knowledge graph."""

    def __init__(self, dataset: KGDataset) -> None:
        super().__init__(dataset)

    def compute(
        self,
        H: int = 3,
        *,
        return_full: bool = False,
        progress: bool = True,
    ) -> Tuple[float, float, float] | Tuple[float, List[float], float, float]:
        """Compute the *average* Weisfeiler-Lehman Entropy Complexity (WLEC).

        Parameters
        ----------
        H:
            Number of refinement iterations **after** the initial colouring (so the
            colours are updated **H** times and the entropy is measured **H+1** times).
        return_full:
            If *True*, additionally return the list of entropies for each depth.
        progress:
            Whether to print a mini progress bar (no external dependency).

        Returns
        -------
        (average_entropy, c_ratio, c_nwlec)
            When ``return_full`` is false.
        (average_entropy, entropies, c_ratio, c_nwlec)
            When ``return_full`` is true; ``entropies`` holds the ``H+1`` per-layer values.

        Raises
        ------
        ValueError
            If ``H`` is negative or the dataset reports zero entities.
        """
        # A negative H used to end in ZeroDivisionError when averaging an empty list.
        if H < 0:
            msg = f"H must be non-negative, got {H}"
            raise ValueError(msg)

        # NOTE(review): `num_entities`, `out_adj` and `in_adj` are not set in this
        # class; presumably provided by BaseMetric — confirm.
        n = int(self.num_entities)
        if n == 0:
            msg = "Dataset appears to contain zero entities."
            raise ValueError(msg)

        # Colour assignment - we keep it as a *list* of ints for cheap hashing.
        # Every entity starts with colour 1; refined colour ids also start at 1.
        colour: List[int] = [1] * n

        entropies: List[float] = []

        # Optional poor-man's progress bar.
        def _print_prog(i: int) -> None:
            if progress:
                print(f"\rIteration {i}/{H}", end="", flush=True)

        for h in range(H + 1):
            _print_prog(h)

            # Step 1: measure entropy of current colouring
            freq = Counter(colour)
            entropies.append(_entropy_from_counter(freq, n))

            if h == H:
                break  # We have reached the requested depth - stop here.

            # Step 2: create refined colours
            bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {}
            next_colour: List[int] = [0] * n

            for v in range(n):
                # Build the multiset T containing outgoing and incoming features.
                T: List[Tuple[int, int, int]] = []
                for r, t in self.out_adj[v]:
                    T.append((r, 0, colour[t]))  # 0 = outgoing
                for r, h_ in self.in_adj[v]:
                    T.append((r, 1, colour[h_]))  # 1 = incoming
                T.sort()  # canonical ordering

                key = (colour[v], tuple(T))
                if key not in bucket:
                    bucket[key] = len(bucket) + 1  # new colour ID (start at 1)
                next_colour[v] = bucket[key]

            colour = next_colour  # move to next iteration

        _print_prog(H)
        if progress:
            print()  # newline after progress bar

        average_entropy = sum(entropies) / len(entropies)

        c_ratio, c_nwlec = _normalise_scores(entropies, n)

        if return_full:
            return average_entropy, entropies, c_ratio, c_nwlec

        return average_entropy, c_ratio, c_nwlec

+ 0
- 0
models/__init__.py View File


+ 127
- 0
models/base_model.py View File

@@ -0,0 +1,127 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch
from torch import nn
from torch.nn.functional import log_softmax, softmax

if TYPE_CHECKING:
from pykeen.typing import FloatTensor, InductiveMode, LongTensor, MappedTriples, Target


class ERModel(nn.Module, ABC):
    """Base class for knowledge graph models.

    Subclasses are expected to assign the wrapped scoring model to
    ``self.inner_model``; every scoring method here delegates to it.
    """

    def __init__(self, num_entities: int, num_relations: int, dim: int) -> None:
        super().__init__()
        self.num_entities = num_entities
        self.num_relations = num_relations
        self.dim = dim

        # Set by subclasses; stays None until then.
        self.inner_model = None

    def _require_inner_model(self, caller: str) -> nn.Module:
        """Return ``self.inner_model`` or raise ``ValueError`` naming *caller* if unset."""
        inner = self.inner_model
        # Identity check: a module defining ``__len__`` could be falsy while present.
        if inner is None:
            msg = (
                "Inner model is not set. Please initialize the inner model "
                f"before calling {caller}."
            )
            raise ValueError(msg)
        return inner

    @property
    def device(self) -> torch.device:
        """Return the device of the model."""
        return next(self.inner_model.parameters()).device

    @abstractmethod
    def reset_parameters(self, *args, **kwargs) -> None:
        """Reset the parameters of the model."""

    def predict(
        self,
        hrt_batch: MappedTriples,
        target: Target,
        full_batch: bool = True,
        ids: LongTensor | None = None,
        **kwargs,
    ) -> FloatTensor:
        """Delegate to ``inner_model.predict`` with the same arguments.

        Raises
        ------
        ValueError
            If the inner model has not been set.
        """
        inner = self._require_inner_model("predict")
        return inner.predict(
            hrt_batch=hrt_batch,
            target=target,
            full_batch=full_batch,
            ids=ids,
            **kwargs,
        )

    def forward(
        self,
        triples: MappedTriples,
        slice_size: int | None = None,
        slice_dim: int = 0,
        *,
        mode: InductiveMode | None = None,
    ) -> FloatTensor:
        """Score *triples* (columns: head, relation, tail) with the inner model.

        ``mode`` now defaults to ``None`` so that a plain ``model(triples)`` call
        works; explicit ``mode=...`` callers are unaffected.

        Raises
        ------
        ValueError
            If the inner model has not been set.
        """
        h_indices = triples[:, 0]
        r_indices = triples[:, 1]
        t_indices = triples[:, 2]

        inner = self._require_inner_model("forward")
        return inner.forward(
            h_indices=h_indices,
            r_indices=r_indices,
            t_indices=t_indices,
            slice_size=slice_size,
            slice_dim=slice_dim,
            mode=mode,
        )

    # ----------------------------------------------------------------------- #
    # (0) TAIL-given-(head, relation)  →  pθ(t | h,r)                          #
    # ----------------------------------------------------------------------- #
    @torch.inference_mode()
    def tail_distribution(
        self,
        hr_batch: LongTensor,  # (B, 2) : (h,r)
        *,
        slice_size: int | None = None,
        mode: InductiveMode | None = None,
        log: bool = False,
    ) -> FloatTensor:  # (B, |E|)
        """Softmax (or log-softmax if *log*) over the last axis of ``inner_model.score_t``."""
        inner = self._require_inner_model("tail_distribution")
        scores = inner.score_t(hr_batch=hr_batch, slice_size=slice_size, mode=mode)
        return log_softmax(scores, -1) if log else softmax(scores, -1)

    # ----------------------------------------------------------------------- #
    # (1) HEAD-given-(relation, tail)  →  pθ(h | r,t)                          #
    # ----------------------------------------------------------------------- #
    @torch.inference_mode()
    def head_distribution(
        self,
        rt_batch: LongTensor,  # (B, 2) : (r,t)
        *,
        slice_size: int | None = None,
        mode: InductiveMode | None = None,
        log: bool = False,
    ) -> FloatTensor:  # (B, |E|)
        """Softmax (or log-softmax if *log*) over the last axis of ``inner_model.score_h``."""
        inner = self._require_inner_model("head_distribution")
        scores = inner.score_h(rt_batch=rt_batch, slice_size=slice_size, mode=mode)
        return log_softmax(scores, -1) if log else softmax(scores, -1)

    # ----------------------------------------------------------------------- #
    # (2) RELATION-given-(head, tail)  →  pθ(r | h,t)                          #
    # ----------------------------------------------------------------------- #
    @torch.inference_mode()
    def relation_distribution(
        self,
        ht_batch: LongTensor,  # (B, 2) : (h,t)
        *,
        slice_size: int | None = None,
        mode: InductiveMode | None = None,
        log: bool = False,
    ) -> FloatTensor:  # (B, |R|)
        """Softmax (or log-softmax if *log*) over the last axis of ``inner_model.score_r``."""
        inner = self._require_inner_model("relation_distribution")
        scores = inner.score_r(ht_batch=ht_batch, slice_size=slice_size, mode=mode)
        return log_softmax(scores, -1) if log else softmax(scores, -1)

+ 0
- 0
models/translation/__init__.py View File


BIN
models/translation/__pycache__/__init__.cpython-310.pyc View File


BIN
models/translation/__pycache__/trans_e.cpython-310.pyc View File


+ 70
- 0
models/translation/trans_e.py View File

@@ -0,0 +1,70 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from pykeen.models import TransE as PyKEENTransE
from torch import nn

from models.base_model import ERModel

if TYPE_CHECKING:
from pykeen.typing import MappedTriples


class TransE(ERModel):
    """
    TransE model for knowledge graph embedding.

    Thin wrapper that builds a PyKEEN ``TransE`` as ``self.inner_model``; the
    interaction is based on the distance ``||h + r - t||``.
    """

    def __init__(
        self,
        num_entities: int,
        num_relations: int,
        triples_factory: MappedTriples,
        dim: int = 200,
        sf_norm: int = 1,
        p_norm: bool = True,
        margin: float | None = None,
        epsilon: float | None = None,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        """Build the wrapped PyKEEN model and move it to *device*.

        ``triples_factory`` is forwarded unchanged to PyKEEN; ``sf_norm`` becomes
        ``scoring_fct_norm`` and ``p_norm`` becomes ``power_norm``. ``margin`` and
        ``epsilon`` are only stored on the instance.
        """
        # `self.dim` is set by the base-class constructor.
        super().__init__(num_entities, num_relations, dim)
        self.sf_norm = sf_norm
        self.p_norm = p_norm
        self.margin = margin
        self.epsilon = epsilon

        self.inner_model = PyKEENTransE(
            triples_factory=triples_factory,
            embedding_dim=dim,
            scoring_fct_norm=sf_norm,
            power_norm=p_norm,
            random_seed=42,
        ).to(device)

    def _embedding_weights(self) -> list[torch.Tensor]:
        """Collect the parameters of the primary entity and relation representations."""
        # The embeddings live inside the wrapped PyKEEN model; this wrapper never
        # defines `ent_embeddings` / `rel_embeddings`, so reading those attributes
        # (as the previous implementation did) always raised AttributeError.
        entity_rep = self.inner_model.entity_representations[0]
        relation_rep = self.inner_model.relation_representations[0]
        return [*entity_rep.parameters(), *relation_rep.parameters()]

    def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None:
        """Re-initialise the entity and relation embeddings in place.

        With both *margin* and *epsilon* given, weights are drawn uniformly from
        ``[-(margin+epsilon)/dim, +(margin+epsilon)/dim]``; otherwise Xavier-uniform
        initialisation is used.
        """
        weights = self._embedding_weights()
        if margin is None or epsilon is None:
            # If no margin/epsilon are provided, use Xavier uniform initialization
            # (assumes the embedding weights are at least 2-D — TODO confirm).
            for weight in weights:
                nn.init.xavier_uniform_(weight.data)
            return

        self.embedding_range = nn.Parameter(
            torch.Tensor([(margin + epsilon) / self.dim]),
            requires_grad=False,
        )
        bound = self.embedding_range.item()
        for weight in weights:
            nn.init.uniform_(tensor=weight.data, a=-bound, b=bound)

+ 76
- 0
models/translation/trans_h.py View File

@@ -0,0 +1,76 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from pykeen.losses import MarginRankingLoss
from pykeen.models import TransH as PyKEENTransH
from pykeen.regularizers import LpRegularizer
from torch import nn

from models.base_model import ERModel

if TYPE_CHECKING:
from pykeen.typing import MappedTriples


class TransH(ERModel):
    """
    TransH model for knowledge graph embedding.

    Thin wrapper that builds a PyKEEN ``TransH`` as ``self.inner_model``.
    """

    def __init__(
        self,
        num_entities: int,
        num_relations: int,
        triples_factory: MappedTriples,
        dim: int = 200,
        p_norm: bool = True,
        margin: float | None = None,
        epsilon: float | None = None,
        device: torch.device = torch.device("cpu"),
        **kwargs,
    ) -> None:
        """Build the wrapped PyKEEN model and move it to *device*.

        ``triples_factory`` is forwarded unchanged to PyKEEN and ``p_norm`` becomes
        ``power_norm``. Extra ``**kwargs`` are accepted and ignored.
        """
        # `self.dim` is set by the base-class constructor.
        super().__init__(num_entities, num_relations, dim)
        self.p_norm = p_norm
        self.margin = margin
        self.epsilon = epsilon

        self.inner_model = PyKEENTransH(
            triples_factory=triples_factory,
            embedding_dim=dim,
            scoring_fct_norm=1,
            loss=self.loss,
            power_norm=p_norm,
            entity_initializer="xavier_uniform_",  # good default
            relation_initializer="xavier_uniform_",
            random_seed=42,
        ).to(device)

    @property
    def loss(self) -> MarginRankingLoss:
        """Return a fresh mean-reduced margin ranking loss built from ``self.margin``."""
        if self.margin is None:
            # Let PyKEEN apply its own default margin instead of forwarding None.
            return MarginRankingLoss(reduction="mean")
        return MarginRankingLoss(margin=self.margin, reduction="mean")

    def _embedding_weights(self) -> list[torch.Tensor]:
        """Collect the parameters of the primary entity and relation representations."""
        # The embeddings live inside the wrapped PyKEEN model; this wrapper never
        # defines `ent_embeddings` / `rel_embeddings`, so reading those attributes
        # (as the previous implementation did) always raised AttributeError.
        # Only the first relation representation is touched, mirroring the single
        # `rel_embeddings` of the previous code.
        entity_rep = self.inner_model.entity_representations[0]
        relation_rep = self.inner_model.relation_representations[0]
        return [*entity_rep.parameters(), *relation_rep.parameters()]

    def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None:
        """Re-initialise the entity and relation embeddings in place.

        With both *margin* and *epsilon* given, weights are drawn uniformly from
        ``[-(margin+epsilon)/dim, +(margin+epsilon)/dim]``; otherwise Xavier-uniform
        initialisation is used.
        """
        weights = self._embedding_weights()
        if margin is None or epsilon is None:
            # If no margin/epsilon are provided, use Xavier uniform initialization
            # (assumes the embedding weights are at least 2-D — TODO confirm).
            for weight in weights:
                nn.init.xavier_uniform_(weight.data)
            return

        self.embedding_range = nn.Parameter(
            torch.Tensor([(margin + epsilon) / self.dim]),
            requires_grad=False,
        )
        bound = self.embedding_range.item()
        for weight in weights:
            nn.init.uniform_(tensor=weight.data, a=-bound, b=bound)

+ 79
- 0
models/translation/trans_r.py View File

@@ -0,0 +1,79 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from pykeen.losses import MarginRankingLoss
from pykeen.models import TransR as PyKEENTransR
from torch import nn

from models.base_model import ERModel

if TYPE_CHECKING:
from pykeen.typing import MappedTriples


class TransR(ERModel):
    """
    TransR model for knowledge graph embedding.

    Thin wrapper that builds a PyKEEN ``TransR`` as ``self.inner_model``.
    """

    def __init__(
        self,
        num_entities: int,
        num_relations: int,
        triples_factory: MappedTriples,
        entity_dim: int = 200,
        relation_dim: int = 30,
        p_norm: bool = True,
        p_norm_value: int = 2,
        margin: float | None = None,
        epsilon: float | None = None,
        device: torch.device = torch.device("cpu"),
        **kwargs,
    ) -> None:
        """Build the wrapped PyKEEN model and move it to *device*.

        ``triples_factory`` is forwarded unchanged to PyKEEN and ``p_norm`` becomes
        ``power_norm``. Extra ``**kwargs`` are accepted and ignored.
        """
        # The base class stores `entity_dim` as `self.dim`.
        super().__init__(num_entities, num_relations, entity_dim)
        self.entity_dim = entity_dim
        self.relation_dim = relation_dim
        self.p_norm = p_norm
        self.margin = margin
        self.epsilon = epsilon

        # NOTE(review): `p_norm_value` is accepted but never used; `scoring_fct_norm`
        # is fixed to 1 below. Wiring it through would change the default norm
        # (1 -> 2), so it is left untouched — confirm the intended behaviour.
        self.inner_model = PyKEENTransR(
            triples_factory=triples_factory,
            embedding_dim=entity_dim,
            relation_dim=relation_dim,
            scoring_fct_norm=1,
            loss=self.loss,
            power_norm=p_norm,
            entity_initializer="xavier_uniform_",
            relation_initializer="xavier_uniform_",
            random_seed=42,
        ).to(device)

    @property
    def loss(self) -> MarginRankingLoss:
        """Return a fresh mean-reduced margin ranking loss built from ``self.margin``."""
        if self.margin is None:
            # Let PyKEEN apply its own default margin instead of forwarding None.
            return MarginRankingLoss(reduction="mean")
        return MarginRankingLoss(margin=self.margin, reduction="mean")

    def _embedding_weights(self) -> list[torch.Tensor]:
        """Collect the parameters of the primary entity and relation representations."""
        # The embeddings live inside the wrapped PyKEEN model; this wrapper never
        # defines `ent_embeddings` / `rel_embeddings`, so reading those attributes
        # (as the previous implementation did) always raised AttributeError.
        # Only the first relation representation is touched, mirroring the single
        # `rel_embeddings` of the previous code.
        entity_rep = self.inner_model.entity_representations[0]
        relation_rep = self.inner_model.relation_representations[0]
        return [*entity_rep.parameters(), *relation_rep.parameters()]

    def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None:
        """Re-initialise the entity and relation embeddings in place.

        With both *margin* and *epsilon* given, weights are drawn uniformly from
        ``[-(margin+epsilon)/dim, +(margin+epsilon)/dim]`` where ``dim`` is the
        entity dimension; otherwise Xavier-uniform initialisation is used.
        """
        weights = self._embedding_weights()
        if margin is None or epsilon is None:
            # If no margin/epsilon are provided, use Xavier uniform initialization
            # (assumes the embedding weights are at least 2-D — TODO confirm).
            for weight in weights:
                nn.init.xavier_uniform_(weight.data)
            return

        self.embedding_range = nn.Parameter(
            torch.Tensor([(margin + epsilon) / self.dim]),
            requires_grad=False,
        )
        bound = self.embedding_range.item()
        for weight in weights:
            nn.init.uniform_(tensor=weight.data, a=-bound, b=bound)

+ 82
- 0
pyproject.toml View File

@@ -0,0 +1,82 @@
[tool.ruff]
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".ipynb_checkpoints",
".mypy_cache",
".nox",
".pants.d",
".pyenv",
".pytest_cache",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".vscode",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]
respect-gitignore = true

# Black's default line length is 88; this project uses 100.
line-length = 100
indent-width = 4

[tool.ruff.lint]
# Enable every rule family, then opt out of the specific rules listed in `ignore`.
select = ["ALL"]
ignore = ["ANN002", "ANN003", "RET504", "ERA001", "E741", "N812", "D", "PLR0913", "T201", "FBT001", "FBT002", "PLW0602", "PLW0603", "PTH", "C901", "PLR2004", "UP006", "UP035", "G004", "N803", "N806", "ARG002"]

# Allow fix for all enabled rules (when `--fix`) is provided.
# These keys must stay in [tool.ruff.lint]: placed after the per-file-ignores
# table header they would be parsed as file patterns named "fixable"/"unfixable".
fixable = ["ALL"]
unfixable = []

# Allow unused variables when underscore-prefixed.
#dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E", "F", "I", "N"]

[tool.ruff.format]
# Like Black, use double quotes for strings.
quote-style = "double"

# Like Black, indent with spaces, rather than tabs.
indent-style = "space"

# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false

# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"

# Enable auto-formatting of code examples in docstrings. Markdown,
# reStructuredText code/literal blocks and doctests are all supported.
#
# This is currently disabled by default, but it is planned for this
# to be opt-out in the future.
docstring-code-format = false

# Set the line length limit used when formatting code snippets in
# docstrings.
#
# This only has an effect when the `docstring-code-format` setting is
# enabled.
docstring-code-line-length = "dynamic"

[tool.pyright]
typeCheckingMode = "off"

+ 28
- 0
setup/install.sh View File

@@ -0,0 +1,28 @@
#!/usr/bin/env bash
# install.sh – minimal Conda bootstrap with ~/.bashrc check
#
# Creates (or updates) the project Conda environment from kg_env.yaml,
# installing Miniconda into ~/miniconda3 first when `conda` is not available.
set -e

ENV=kg-env                 # environment name passed via `-n` (intended to take precedence over `name:` in the YAML)
YAML=kg_env.yaml           # looked up in the current dir first, then next to this script

# Resolve the spec file so the script also works when invoked from another directory.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
if [ ! -f "$YAML" ]; then
    YAML="$SCRIPT_DIR/kg_env.yaml"
fi
if [ ! -f "$YAML" ]; then
    echo "[install.sh] Environment file not found: $YAML" >&2
    exit 1
fi

# ── 1. try to load an existing conda from ~/.bashrc ──────────────────────
[ -f "$HOME/.bashrc" ] && source "$HOME/.bashrc" || true

# ── 2. ensure conda exists ───────────────────────────────────────────────
if ! command -v conda &>/dev/null; then
    echo "[install.sh] Conda not found → installing Miniconda."
    # NOTE: this installer is the Linux x86_64 build only.
    curl -fsSL https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o miniconda.sh
    bash miniconda.sh -b -p "$HOME/miniconda3"
    rm miniconda.sh
    source "$HOME/miniconda3/etc/profile.d/conda.sh"
    conda init bash >/dev/null   # so future shells have it
else
    # conda exists; load its helper for this shell
    source "$(conda info --base)/etc/profile.d/conda.sh"
fi

# ── 3. create or update the project env ──────────────────────────────────
# `create` fails when the env already exists; fall back to an in-place update.
conda env create -n "$ENV" -f "$YAML" \
    || conda env update -n "$ENV" -f "$YAML" --prune

echo "✔ Done. Activate with: conda activate $ENV"

+ 163
- 0
setup/kg_env.yaml View File

@@ -0,0 +1,163 @@
# kg_env.yaml ── Conda environment spec for the KG-Embeddings project
name: kg
channels:
- pytorch
- defaults
- nvidia
- conda-forge
- https://repo.anaconda.com/pkgs/main
- https://repo.anaconda.com/pkgs/r
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- aom=3.6.1=h59595ed_0
- blas=1.0=mkl
- brotli-python=1.1.0=py311hfdbb021_2
- bzip2=1.0.8=h4bc722e_7
- ca-certificates=2024.12.14=hbcca054_0
- certifi=2024.12.14=pyhd8ed1ab_0
- cffi=1.17.1=py311hf29c0ef_0
- charset-normalizer=3.4.1=pyhd8ed1ab_0
- cpython=3.11.11=py311hd8ed1ab_1
- cuda-cudart=12.4.127=0
- cuda-cupti=12.4.127=0
- cuda-libraries=12.4.1=0
- cuda-nvrtc=12.4.127=0
- cuda-nvtx=12.4.127=0
- cuda-opencl=12.6.77=0
- cuda-runtime=12.4.1=0
- cuda-version=12.6=3
- ffmpeg=4.4.2=gpl_hdf48244_113
- filelock=3.16.1=pyhd8ed1ab_1
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
- font-ttf-inconsolata=3.000=h77eed37_0
- font-ttf-source-code-pro=2.038=h77eed37_0
- font-ttf-ubuntu=0.83=h77eed37_3
- fontconfig=2.15.0=h7e30c49_1
- fonts-conda-ecosystem=1=0
- fonts-conda-forge=1=0
- freetype=2.12.1=h267a509_2
- gettext=0.22.5=he02047a_3
- gettext-tools=0.22.5=he02047a_3
- giflib=5.2.2=hd590300_0
- gmp=6.3.0=hac33072_2
- gmpy2=2.1.5=py311h0f6cedb_3
- gnutls=3.7.9=hb077bed_0
- h2=4.1.0=pyhd8ed1ab_1
- hpack=4.0.0=pyhd8ed1ab_1
- hyperframe=6.0.1=pyhd8ed1ab_1
- idna=3.10=pyhd8ed1ab_1
- intel-openmp=2022.0.1=h06a4308_3633
- jinja2=3.1.5=pyhd8ed1ab_0
- lame=3.100=h166bdaf_1003
- lcms2=2.16=hb7c19ff_0
- ld_impl_linux-64=2.43=h712a8e2_2
- lerc=4.0.0=h27087fc_0
- libasprintf=0.22.5=he8f35ee_3
- libasprintf-devel=0.22.5=he8f35ee_3
- libblas=3.9.0=16_linux64_mkl
- libcblas=3.9.0=16_linux64_mkl
- libcublas=12.4.5.8=0
- libcufft=11.2.1.3=0
- libcufile=1.11.1.6=0
- libcurand=10.3.7.77=0
- libcusolver=11.6.1.9=0
- libcusparse=12.3.1.170=0
- libdeflate=1.23=h4ddbbb0_0
- libdrm=2.4.124=hb9d3cd8_0
- libegl=1.7.0=ha4b6fd6_2
- libexpat=2.6.4=h5888daf_0
- libffi=3.4.2=h7f98852_5
- libgcc=14.2.0=h77fa898_1
- libgcc-ng=14.2.0=h69a702a_1
- libgettextpo=0.22.5=he02047a_3
- libgettextpo-devel=0.22.5=he02047a_3
- libgl=1.7.0=ha4b6fd6_2
- libglvnd=1.7.0=ha4b6fd6_2
- libglx=1.7.0=ha4b6fd6_2
- libgomp=14.2.0=h77fa898_1
- libiconv=1.17=hd590300_2
- libidn2=2.3.7=hd590300_0
- libjpeg-turbo=3.0.0=hd590300_1
- liblapack=3.9.0=16_linux64_mkl
- liblzma=5.6.3=hb9d3cd8_1
- libnpp=12.2.5.30=0
- libnsl=2.0.1=hd590300_0
- libnvfatbin=12.6.77=0
- libnvjitlink=12.4.127=0
- libnvjpeg=12.3.1.117=0
- libpciaccess=0.18=hd590300_0
- libpng=1.6.45=h943b412_0
- libsqlite=3.47.2=hee588c1_0
- libstdcxx=14.2.0=hc0a3c3a_1
- libstdcxx-ng=14.2.0=h4852527_1
- libtasn1=4.19.0=h166bdaf_0
- libtiff=4.7.0=hd9ff511_3
- libunistring=0.9.10=h7f98852_0
- libuuid=2.38.1=h0b41bf4_0
- libva=2.22.0=h8a09558_1
- libvpx=1.13.1=h59595ed_0
- libwebp=1.5.0=hae8dbeb_0
- libwebp-base=1.5.0=h851e524_0
- libxcb=1.17.0=h8a09558_0
- libxcrypt=4.4.36=hd590300_1
- libxml2=2.13.5=h0d44e9d_1
- libzlib=1.3.1=hb9d3cd8_2
- llvm-openmp=15.0.7=h0cdce71_0
- markupsafe=3.0.2=py311h2dc5d0c_1
- mkl=2022.1.0=hc2b9512_224
- mpc=1.3.1=h24ddda3_1
- mpfr=4.2.1=h90cbb55_3
- mpmath=1.3.0=pyhd8ed1ab_1
- ncurses=6.5=h2d0b736_2
- nettle=3.9.1=h7ab15ed_0
- networkx=3.4.2=pyh267e887_2
- numpy=2.2.1=py311hf916aec_0
- openh264=2.3.1=hcb278e6_2
- openjpeg=2.5.3=h5fbd93e_0
- openssl=3.4.0=h7b32b05_1
- p11-kit=0.24.1=hc5aa10d_0
- pillow=11.1.0=py311h1322bbf_0
- pip=24.3.1=pyh8b19718_2
- pthread-stubs=0.4=hb9d3cd8_1002
- pycparser=2.22=pyh29332c3_1
- pysocks=1.7.1=pyha55dd90_7
- python=3.11.11=h9e4cc4f_1_cpython
- python_abi=3.11=5_cp311
- pytorch=2.5.1=py3.11_cuda12.4_cudnn9.1.0_0
- pytorch-cuda=12.4=hc786d27_7
- pytorch-mutex=1.0=cuda
- pyyaml=6.0.2=py311h9ecbd09_1
- readline=8.2=h8228510_1
- requests=2.32.3=pyhd8ed1ab_1
- setuptools=75.8.0=pyhff2d567_0
- svt-av1=1.4.1=hcb278e6_0
- sympy=1.13.3=pyh2585a3b_105
- tk=8.6.13=noxft_h4845f30_101
- torchaudio=2.5.1=py311_cu124
- torchtriton=3.1.0=py311
- torchvision=0.20.1=py311_cu124
- typing_extensions=4.12.2=pyha770c72_1
- tzdata=2024b=hc8b5060_0
- urllib3=2.3.0=pyhd8ed1ab_0
- wayland=1.23.1=h3e06ad9_0
- wayland-protocols=1.37=hd8ed1ab_0
- wheel=0.45.1=pyhd8ed1ab_1
- x264=1!164.3095=h166bdaf_2
- x265=3.5=h924138e_3
- xorg-libx11=1.8.10=h4f16b4b_1
- xorg-libxau=1.0.12=hb9d3cd8_0
- xorg-libxdmcp=1.1.5=hb9d3cd8_0
- xorg-libxext=1.3.6=hb9d3cd8_0
- xorg-libxfixes=6.0.1=hb9d3cd8_0
- yaml=0.2.5=h7f98852_2
- zstandard=0.23.0=py311hbc35293_1
- zstd=1.5.6=ha6fb4c9_0
- pip:
- torchmetrics>=1.4 # nicer MRR/Hits@K utilities
- lovely-tensors>=0.1.0 # tensor utilities
- lovely-numpy>=0.1.0 # numpy utilities
- einops==0.8.1
- pykeen==1.11.1
- hydra-core==1.3.2
- tensorboard==2.19.0

+ 6
- 0
tools/__init__.py View File

@@ -0,0 +1,6 @@
from .pretty_logger import get_pretty_logger
from .sampling import DeterministicRandomSampler
from .utils import set_seed
from .tb_handler import TensorBoardHandler
from .params import CommonParams, TrainingParams
from .checkpoint_manager import CheckpointManager

+ 221
- 0
tools/checkpoint_manager.py View File

@@ -0,0 +1,221 @@
from __future__ import annotations

import re
from pathlib import Path


class CheckpointManager:
    """
    Manage checkpoint directories and resolve checkpoint file paths.

    Features:
    - Creates sequentially numbered run directories named ``NNN_<run_name>``
      below a root directory.
    - Resolves checkpoint files either from components or from a full path.
    """

    def __init__(
        self,
        root_directory: str | Path,
        run_name: str,
        load_only: bool = False,
    ) -> None:
        """
        Initialize the checkpoint manager.

        Args:
            root_directory: Root directory for checkpoints
            run_name: Name of the run (used as suffix for checkpoint directories)
            load_only: When True no directory is created and
                ``checkpoint_directory`` is left as an empty string.
        """
        self.root_directory = Path(root_directory)
        # The un-numbered name is kept separately because ``run_name`` is
        # replaced by the numbered directory name once a directory is created.
        self._base_run_name = run_name
        self.run_name = run_name
        self.checkpoint_directory = None

        if not load_only:
            self.root_directory.mkdir(parents=True, exist_ok=True)
            self.checkpoint_directory = self._create_checkpoint_directory()
        else:
            self.checkpoint_directory = ""

    def _find_existing_directories(self) -> list[int]:
        """
        Find the sequence numbers of existing ``NNN_<run_name>`` directories.

        Numbers are zero-padded to at least three digits; longer numbers
        (1000 and above) are recognised as well.

        Returns:
            Sorted list of existing sequence numbers
        """
        pattern = re.compile(rf"^(\d{{3,}})_{re.escape(self._base_run_name)}$")
        existing_numbers = []

        if self.root_directory.exists():
            for item in self.root_directory.iterdir():
                if not item.is_dir():
                    continue
                match = pattern.match(item.name)
                if match:
                    existing_numbers.append(int(match.group(1)))

        return sorted(existing_numbers)

    def _create_checkpoint_directory(self) -> Path:
        """
        Create a new checkpoint directory with the next sequential number.

        Side effect: ``self.run_name`` becomes the numbered directory name.

        Returns:
            Path to the created checkpoint directory
        """
        existing_numbers = self._find_existing_directories()

        # Determine the next number (numbering starts at 1)
        next_number = max(existing_numbers) + 1 if existing_numbers else 1

        # Directory name with a zero-padded (at least 3-digit) number
        dir_name = f"{next_number:03d}_{self._base_run_name}"
        checkpoint_dir = self.root_directory / dir_name
        self.run_name = dir_name

        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        return checkpoint_dir

    def get_checkpoint_directory(self) -> Path:
        """
        Get the current checkpoint directory.

        Returns:
            Path to the checkpoint directory ("" when created with load_only=True)
        """
        return self.checkpoint_directory

    def get_model_fpath(self, model_path: str) -> Path:
        """
        Get the full path to a model checkpoint file.

        Args:
            model_path: Either a string path, or a ``(model_id, iteration)``
                tuple given directly or as its literal text, e.g. ``"(3, 500)"``.

        Returns:
            ``<root>/<model_id:03d>_<run_name>/model-<iteration>.pt`` for a
            tuple, otherwise the given path.

        Raises:
            ValueError: If the input is neither a path nor a 2-tuple.
        """
        import ast

        msg = "model_path must be a tuple (model_id, iteration) or a string path"

        if isinstance(model_path, Path):
            return model_path

        spec = model_path
        if isinstance(model_path, str):
            # literal_eval only parses Python literals, so arbitrary text from a
            # config file or the command line is never executed.
            try:
                spec = ast.literal_eval(model_path)
            except (ValueError, TypeError, SyntaxError):
                # Not a Python literal at all -> it is an ordinary path.
                return Path(model_path)
            if isinstance(spec, str):
                # A quoted string literal: use its content as the path.
                return Path(spec)

        if isinstance(spec, tuple):
            if len(spec) != 2:
                raise ValueError(msg)
            model_id, iteration = spec
            # Use the un-numbered run name; self.run_name may already carry
            # this manager's own "NNN_" prefix.
            run_directory = self.root_directory / f"{model_id:03d}_{self._base_run_name}"
            return run_directory / f"model-{iteration}.pt"

        raise ValueError(msg)

    def get_model_path_by_args(
        self,
        model_id: str | None = None,
        iteration: int | str | None = None,
        full_path: str | Path | None = None,
    ) -> Path:
        """
        Get the path for loading a model checkpoint.

        Two modes of operation:
        1. Component mode: Provide model_id and iteration to construct path
        2. Full path mode: Provide full_path directly

        Args:
            model_id: Model identifier (used in component mode)
            iteration: Training iteration (used in component mode)
            full_path: Full path to checkpoint (used in full path mode)

        Returns:
            Path to the checkpoint file

        Raises:
            ValueError: If neither component parameters nor full_path are provided,
                if both modes are attempted simultaneously, or if only one of
                model_id / iteration is given.
        """
        # Check which mode we're operating in
        component_mode = model_id is not None or iteration is not None
        full_path_mode = full_path is not None

        if component_mode and full_path_mode:
            msg = (
                "Cannot use both component mode (model_id/iteration) "
                "and full path mode simultaneously"
            )
            raise ValueError(msg)

        if not component_mode and not full_path_mode:
            msg = "Must provide either (model_id and iteration) or full_path"
            raise ValueError(msg)

        if full_path_mode:
            # Full path mode: return the path as-is
            return Path(full_path)

        # Component mode: both parts are required
        if model_id is None or iteration is None:
            msg = "Both model_id and iteration must be provided in component mode"
            raise ValueError(msg)

        # NOTE(review): with load_only=True checkpoint_directory is "" and the
        # division below raises TypeError.
        filename = f"{model_id}_iter_{iteration}.pt"
        return self.checkpoint_directory / filename

    def save_checkpoint_info(self, info: dict) -> None:
        """
        Save checkpoint information to ``checkpoint_info.json`` in the checkpoint directory.

        Args:
            info: Dictionary containing checkpoint metadata
        """
        import json

        info_file = self.checkpoint_directory / "checkpoint_info.json"
        with open(info_file, "w", encoding="utf-8") as f:
            json.dump(info, f, indent=2)

    def load_checkpoint_info(self) -> dict:
        """
        Load checkpoint information from the checkpoint directory.

        Returns:
            Dictionary containing checkpoint metadata, or {} if the file is missing
        """
        import json

        info_file = self.checkpoint_directory / "checkpoint_info.json"
        if info_file.exists():
            with open(info_file, encoding="utf-8") as f:
                return json.load(f)
        return {}

    def __str__(self) -> str:
        return (
            f"CheckpointManager(root='{self.root_directory}', "
            f"run='{self.run_name}', ckpt_dir='{self.checkpoint_directory}')"
        )

    def __repr__(self) -> str:
        return self.__str__()


# Example usage:
if __name__ == "__main__":
    import tempfile

    # Demo run inside a temporary directory that is removed when the block exits.
    with tempfile.TemporaryDirectory() as scratch_root:
        manager = CheckpointManager(scratch_root, "my_experiment")
        print(f"Checkpoint directory: {manager.get_checkpoint_directory()}")

        # Component mode: the file name is derived from model_id and iteration.
        component_path = manager.get_model_path_by_args(model_id="transe", iteration=1000)
        print(f"Model path (component mode): {component_path}")

        # Full path mode: the given location is returned unchanged (as a Path).
        explicit_location = "/some/other/path/model.pt"
        direct_path = manager.get_model_path_by_args(full_path=explicit_location)
        print(f"Model path (full path mode): {direct_path}")

+ 45
- 0
tools/params.py View File

@@ -0,0 +1,45 @@
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class CommonParams:
    """
    Application-wide parameters: run identification, checkpoint
    saving/loading and logging cadence.
    """

    # Required fields (no defaults, so they must be declared first).
    model_name: str
    project_name: str
    run_name: str
    save_dpath: str

    # Checkpoint-related fields.
    save_every: int = 1000
    load_path: str | None = ""  # empty string by default; None is also accepted

    # Logging-related fields.
    log_dir: str = "./runs"
    log_every: int = 10
    log_console_every: int = 100

    evaluate_only: bool = False


@dataclass
class TrainingParams:
    """
    Hyper-parameters for a training run; every field has a default.
    """

    # Optimizer / learning-rate schedule settings.
    lr: float = 0.001
    min_lr: float = 0.0001
    weight_decay: float = 0.0
    t0: int = 50
    lr_step: int = 10
    gamma: float = 1.0

    # Data loading and reproducibility.
    batch_size: int = 128
    num_workers: int = 0
    seed: int = 42

    # Training length and evaluation cadence.
    num_train_steps: int = 100
    num_epochs: int = 100
    eval_every: int = 5

    # Validation settings.
    validation_sample_size: int = 1000
    validation_batch_size: int = 128

+ 417
- 0
tools/pretty_logger.py View File

@@ -0,0 +1,417 @@
# pretty_logger.py

import json
import logging
import logging.handlers
import sys
import threading
import time
from contextlib import ContextDecorator
from functools import wraps

# ANSI escape sequences: RESET clears attributes, the rest select a foreground colour.
RESET = "\x1b[0m"
BLACK = "\x1b[30m"
RED = "\x1b[31m"
GREEN = "\x1b[32m"
YELLOW = "\x1b[33m"
BLUE = "\x1b[34m"
MAGENTA = "\x1b[35m"
CYAN = "\x1b[36m"
WHITE = "\x1b[37m"

# Colour used for each standard logging level.
_LEVEL_COLORS = dict(
    [
        (logging.DEBUG, CYAN),
        (logging.INFO, GREEN),
        (logging.WARNING, YELLOW),
        (logging.ERROR, RED),
        (logging.CRITICAL, MAGENTA),
    ]
)


def _supports_color() -> bool:
    """
    Report whether colour output should be used: True exactly when
    ``sys.stdout`` has an ``isatty`` method and it reports a terminal.
    """
    if not hasattr(sys.stdout, "isatty"):
        return False
    if not sys.stdout.isatty():
        return False
    # Every platform (including win32) is treated alike once stdout is a terminal.
    return True


class PrettyFormatter(logging.Formatter):
    """
    Formatter producing either a single (optionally colourised) text line
    or a JSON document per record.
    """

    def __init__(
        self,
        fmt: str = None,
        datefmt: str = "%Y-%m-%d %H:%M:%S",
        use_colors: bool = True,
        json_output: bool = False,
    ):
        super().__init__(fmt=fmt, datefmt=datefmt)
        self.json_output = json_output
        # Colours are only used when requested AND the terminal check passes.
        self.use_colors = use_colors and _supports_color()

    def format(self, record: logging.LogRecord) -> str:
        """Render *record* as JSON when ``json_output`` is set, else as a text line."""
        if self.json_output:
            return self._as_json(record)
        return self._as_text(record)

    def _as_json(self, record: logging.LogRecord) -> str:
        """Serialise the core record fields, then extra fields, then any traceback."""
        payload = {
            "timestamp": self.formatTime(record, self.datefmt),
            "name": record.name,
            "level": record.levelname,
            "message": record.getMessage(),
        }
        # Extra/contextual fields may overwrite the core keys above.
        payload.update(record.__dict__.get("extra_fields", {}).items())
        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)
        return json.dumps(payload, ensure_ascii=False)

    def _as_text(self, record: logging.LogRecord) -> str:
        """Build ``[timestamp] [LEVEL] [name]: message`` plus extras and traceback."""
        stamp = self.formatTime(record, self.datefmt)
        line = f"[{stamp}] [{record.levelname}] [{record.name}]: {record.getMessage()}"

        # Contextual fields are appended as "key1=val1 key2=val2".
        context = record.__dict__.get("extra_fields", {})
        if context:
            rendered = " ".join(f"{k}={v}" for k, v in context.items())
            line = f"{line} :: {rendered}"

        # The traceback, if any, follows on its own lines.
        if record.exc_info:
            line = f"{line}\n{self.formatException(record.exc_info)}"

        if not self.use_colors:
            return line
        colour = _LEVEL_COLORS.get(record.levelno, WHITE)
        return f"{colour}{line}{RESET}"


class JsonFormatter(PrettyFormatter):
    """
    PrettyFormatter preset that always emits JSON and never uses colours.
    """

    def __init__(self, datefmt: str = "%Y-%m-%d %H:%M:%S"):
        preset = {
            "fmt": None,
            "datefmt": datefmt,
            "use_colors": False,
            "json_output": True,
        }
        super().__init__(**preset)


class ContextFilter(logging.Filter):
    """
    Guarantee that every record carries an ``extra_fields`` dict: the
    record's ``extra`` attribute when that is a dict, otherwise ``{}``.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        candidate = getattr(record, "extra", None)
        record.extra_fields = candidate if isinstance(candidate, dict) else {}
        # Never drops a record; the filter is used only for its side effect.
        return True


class Timer(ContextDecorator):
    """
    Context manager & decorator to measure execution time of a code block or function.

    Usage as context manager:
        with Timer(logger, "block name"):
            # code ...
    Usage as decorator:
        @Timer(logger, "function foo")
        def foo(...):
            ...

    Attributes:
        start_time: Wall-clock timestamp (``time.time()``) taken on entry.
        elapsed: Duration in seconds of the most recently finished block,
            or None before the first block has finished.
    """

    def __init__(self, logger: logging.Logger, name: str = None, level: int = logging.INFO):
        """
        Args:
            logger: Logger that receives the timing message.
            name: Optional label included in the message.
            level: Logging level used for the message.
        """
        self.logger = logger
        self.name = name or ""
        self.level = level
        self.elapsed = None

    def __enter__(self):
        self.start_time = time.time()
        # Durations use the monotonic high-resolution clock so that system
        # clock adjustments cannot produce negative or skewed timings.
        self._perf_start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, exc_tb):
        elapsed = time.perf_counter() - self._perf_start
        self.elapsed = elapsed
        label = f"[Timer:{self.name}]" if self.name else "[Timer]"
        self.logger.log(self.level, f"{label} elapsed {elapsed:.4f} sec")
        return False  # do not suppress exceptions


def log_function(logger: logging.Logger, level: int = logging.DEBUG):
    """
    Decorator factory to log function entry/exit and execution time.

    Usage:
        @log_function(logger, level=logging.INFO)
        def my_func(...):
            ...

    Args:
        logger: Logger that receives the entry/exit messages.
        level: Level used for the entry and exit messages. A failure is
            always reported through ``logger.exception``.

    Returns:
        A decorator; the wrapped function returns the original result and
        re-raises any exception after logging it.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            logger.log(level, f"➤ Entering {func.__name__}()")
            # Monotonic clock: immune to wall-clock adjustments during the call.
            start = time.perf_counter()
            try:
                result = func(*args, **kwargs)
            except Exception:
                logger.exception(f"✖ Exception in {func.__name__}()")
                raise
            elapsed = time.perf_counter() - start
            logger.log(level, f"✔ Exiting {func.__name__}(), elapsed {elapsed:.4f} sec")
            return result

        return wrapper

    return decorator


class PrettyLogger:
    """
    Wrapper around a ``logging.Logger`` with pretty/colourised console output,
    an optional rotating file handler, an optional JSON format and
    contextual fields (see ``bind``).

    Instances are cached per name: constructing the class twice with the
    same name returns the same object, and only the first construction
    configures the underlying logger.
    """

    _instances = {}
    _lock = threading.Lock()

    def __new__(cls, name: str = "root", *args, **kwargs):
        """
        Return the cached instance for *name*, creating it on first use.

        Extra positional/keyword arguments are accepted (and ignored here)
        so that every ``__init__`` argument may be passed positionally.
        """
        with cls._lock:
            if name not in cls._instances:
                instance = super().__new__(cls)
                cls._instances[name] = instance
            return cls._instances[name]

    def __init__(
        self,
        name: str = "root",
        level: int = logging.DEBUG,
        log_to_console: bool = True,
        console_level: int = logging.DEBUG,
        log_to_file: bool = False,
        log_file: str = "app.log",
        file_level: int = logging.INFO,
        max_bytes: int = 10 * 1024 * 1024,  # 10 MB
        backup_count: int = 5,
        formatter: PrettyFormatter = None,
        json_output: bool = False,
    ):
        """
        Initialize a PrettyLogger (a no-op if the named logger was already configured).

        Args:
            name: logger name.
            level: root level for the logger (lowest level that will be processed).
            log_to_console: whether to attach a console (StreamHandler).
            console_level: level for console output.
            log_to_file: whether to attach a rotating file handler.
            log_file: path to the log file.
            file_level: level for file output.
            max_bytes: rotation threshold in bytes.
            backup_count: number of backup files to keep.
            formatter: console formatter. If None, a default PrettyFormatter is created.
            json_output: if True, the default console formatter and the file
                formatter emit JSON.
        """
        logger = logging.getLogger(name)
        # Bind first so the attribute exists even on the early-return path.
        self.logger = logger

        # Avoid re-initializing (and duplicating handlers) if already done
        if getattr(logger, "_pretty_logger_inited", False):
            return

        self.logger.setLevel(level)
        self.logger.propagate = False  # avoid double-logging if root logger is configured elsewhere

        # Defined unconditionally: the file handler below needs it even when
        # the caller supplies a custom console formatter.
        fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s"

        # Create a default formatter if not supplied
        if formatter is None:
            formatter = PrettyFormatter(
                fmt=fmt, use_colors=not json_output, json_output=json_output
            )

        # Add ContextFilter so extra_fields is always present
        context_filter = ContextFilter()
        self.logger.addFilter(context_filter)

        # Console handler
        if log_to_console:
            ch = logging.StreamHandler(sys.stdout)
            ch.setLevel(console_level)
            ch.setFormatter(formatter)
            self.logger.addHandler(ch)

        # Rotating file handler
        if log_to_file:
            fh = logging.handlers.RotatingFileHandler(
                filename=log_file,
                maxBytes=max_bytes,
                backupCount=backup_count,
                encoding="utf-8",
            )
            fh.setLevel(file_level)
            # Files never get colour codes
            file_formatter = (
                JsonFormatter(datefmt="%Y-%m-%d %H:%M:%S")
                if json_output
                else PrettyFormatter(fmt=fmt, use_colors=False, json_output=False)
            )
            fh.setFormatter(file_formatter)
            self.logger.addHandler(fh)

        # Mark as initialized
        setattr(self.logger, "_pretty_logger_inited", True)

    def addHandler(self, handler: logging.Handler):
        """
        Add a custom handler (any logging.Handler instance) to the logger.
        """
        self.logger.addHandler(handler)

    def bind(self, **kwargs):
        """
        Return a LoggerAdapter that automatically injects extra fields into all log records.

        Example:
            ctx_logger = pretty_logger.bind(request_id=123, user="alice")
            ctx_logger.info("Something happened")  # will include request_id and user in the output
        """
        return logging.LoggerAdapter(self.logger, {"extra": kwargs})

    def add_stream_handler(
        self, stream=None, level: int = logging.DEBUG, formatter: PrettyFormatter = None
    ):
        """
        Add an additional stream handler (e.g. to stderr) and return it.

        Falls back to ``sys.stdout`` and a default PrettyFormatter when
        *stream* / *formatter* are not given.
        """
        handler = logging.StreamHandler(stream or sys.stdout)
        handler.setLevel(level)
        if formatter is None:
            fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s"
            formatter = PrettyFormatter(fmt=fmt)
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)
        return handler

    def add_rotating_file_handler(
        self,
        log_file: str,
        level: int = logging.INFO,
        max_bytes: int = 10 * 1024 * 1024,
        backup_count: int = 5,
        json_output: bool = False,
    ):
        """
        Add another rotating file handler at runtime and return it.
        """
        fh = logging.handlers.RotatingFileHandler(
            filename=log_file,
            maxBytes=max_bytes,
            backupCount=backup_count,
            encoding="utf-8",
        )
        fh.setLevel(level)
        if json_output:
            formatter = JsonFormatter()
        else:
            fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s"
            formatter = PrettyFormatter(fmt=fmt, use_colors=False, json_output=False)
        fh.setFormatter(formatter)
        self.logger.addHandler(fh)
        return fh

    def remove_handler(self, handler: logging.Handler):
        """
        Remove a handler from the logger.
        """
        self.logger.removeHandler(handler)

    def set_level(self, level: int):
        """
        Change the logger's level at runtime.
        """
        self.logger.setLevel(level)

    def get_logger(self) -> logging.Logger:
        """Return the wrapped ``logging.Logger``."""
        return self.logger

    # Convenience methods to expose common logging calls directly from this object:

    def debug(self, msg, *args, **kwargs):
        self.logger.debug(msg, *args, **kwargs)

    def info(self, msg, *args, **kwargs):
        self.logger.info(msg, *args, **kwargs)

    def warning(self, msg, *args, **kwargs):
        self.logger.warning(msg, *args, **kwargs)

    def error(self, msg, *args, **kwargs):
        self.logger.error(msg, *args, **kwargs)

    def critical(self, msg, *args, **kwargs):
        self.logger.critical(msg, *args, **kwargs)

    def exception(self, msg, *args, exc_info=True, **kwargs):
        """
        Log an error message with exception traceback. By default exc_info=True.
        """
        self.logger.error(msg, *args, exc_info=exc_info, **kwargs)


# Convenience function for a module‐level “global” pretty logger:
# Cache of PrettyLogger objects handed out by get_pretty_logger, keyed by name,
# and the lock that serialises access to it.
_global_loggers = {}
_global_lock = threading.Lock()


def get_pretty_logger(
    name: str = "root",
    level: int = logging.DEBUG,
    log_to_console: bool = True,
    console_level: int = logging.DEBUG,
    log_to_file: bool = False,
    log_file: str = "app.log",
    file_level: int = logging.INFO,
    max_bytes: int = 10 * 1024 * 1024,
    backup_count: int = 5,
    json_output: bool = False,
) -> PrettyLogger:
    """
    Fetch the PrettyLogger registered under `name`.

    The first call for a given name builds the logger from the supplied
    settings; later calls return the cached object and ignore the settings.
    """
    with _global_lock:
        try:
            return _global_loggers[name]
        except KeyError:
            pass

        settings = {
            "name": name,
            "level": level,
            "log_to_console": log_to_console,
            "console_level": console_level,
            "log_to_file": log_to_file,
            "log_file": log_file,
            "file_level": file_level,
            "max_bytes": max_bytes,
            "backup_count": backup_count,
            "json_output": json_output,
        }
        _global_loggers[name] = PrettyLogger(**settings)
        return _global_loggers[name]

+ 66
- 0
tools/sampling.py View File

@@ -0,0 +1,66 @@
from collections.abc import Iterator, Sequence

import numpy as np
import torch
from torch.utils.data.sampler import Sampler


def negative_sampling(
    pos: torch.Tensor,
    num_entities: int,
    num_negatives: int = 1,
    mode: Sequence[str] = ("head", "tail"),
) -> torch.Tensor:
    """
    Build corrupted copies of the given triples.

    Each row of ``pos`` is repeated ``num_negatives`` times (consecutively),
    then column 0 and/or column 2 of the copies is overwritten with
    entity ids drawn uniformly from ``[0, num_entities)``. ``pos`` itself is
    not modified.

    Args:
        pos: Triples tensor; indexed as ``pos[:, 0]`` and ``pos[:, 2]``, so it
            needs at least three columns (presumably head, relation, tail).
        num_entities: Exclusive upper bound for the sampled entity ids.
        num_negatives: Number of corrupted copies per input row.
        mode: Which columns to corrupt: contains "head" (column 0) and/or
            "tail" (column 2).
            NOTE(review): with the default BOTH columns are replaced in every
            row; confirm this is intended rather than corrupting one side.

    Returns:
        Tensor with ``pos.shape[0] * num_negatives`` rows on ``pos``'s device.
    """
    corrupted = pos.repeat_interleave(num_negatives, dim=0)
    num_rows = corrupted.shape[0]
    if "head" in mode:
        corrupted[:, 0] = torch.randint(0, num_entities, (num_rows,), device=pos.device)
    if "tail" in mode:
        corrupted[:, 2] = torch.randint(0, num_entities, (num_rows,), device=pos.device)
    return corrupted


class DeterministicRandomSampler(Sampler):
    """
    A deterministic random sampler that selects a fixed number of samples
    in a reproducible manner using a given seed.

    The subset is drawn once in the constructor, so every iteration yields
    the same indices in the same order.
    """

    def __init__(self, dataset_size: int, sample_size: int = 0, seed: int = 42) -> None:
        """
        Args:
            dataset_size (int): Number of items in the dataset; indices are
                drawn from ``range(dataset_size)``.
            sample_size (int, optional): Number of samples to draw. ``0`` (the
                default) or ``None`` means the full dataset.
            seed (int): Seed for reproducibility.

        Raises:
            ValueError: If ``sample_size`` is negative or exceeds ``dataset_size``.
        """
        self.dataset_size = dataset_size
        self.seed = seed
        self.sample_size = self.dataset_size if sample_size in (0, None) else sample_size

        # A negative size would silently slice from the end and make
        # __len__ disagree with the number of yielded indices.
        if self.sample_size < 0:
            msg = f"Sample size {self.sample_size} must be non-negative."
            raise ValueError(msg)

        # Ensure sample size is within dataset size
        if self.sample_size > self.dataset_size:
            msg = f"Sample size {self.sample_size} exceeds dataset size {self.dataset_size}."
            raise ValueError(
                msg,
            )

        self.indices = self._generate_deterministic_indices()

    def _generate_deterministic_indices(self) -> list[int]:
        """
        Generates a fixed random subset of indices using the given seed.
        """
        rng = np.random.default_rng(self.seed)  # local Generator: global NumPy state is untouched
        all_indices = rng.permutation(self.dataset_size)  # shuffle all dataset indices
        return all_indices[: self.sample_size].tolist()  # keep only the requested number

    def __iter__(self) -> Iterator[int]:
        """
        Yields the pre-computed shuffled dataset indices.
        """
        return iter(self.indices)

    def __len__(self) -> int:
        """
        Returns the total number of samples drawn.
        """
        return self.sample_size

+ 21
- 0
tools/tb_handler.py View File

@@ -0,0 +1,21 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from torch.utils.tensorboard import SummaryWriter


class TensorBoardHandler(logging.Handler):
    """
    Convert each log record into a TensorBoard text summary.

    Records are written with ``writer.add_text`` under a fixed tag, using a
    private counter as ``global_step`` so successive records get steps 0, 1, 2, ...
    The handler level is INFO.
    """

    def __init__(self, writer: SummaryWriter, tag: str = "logs"):
        """
        Args:
            writer: Object providing ``add_text(tag, text, global_step=...)``.
            tag: Tag under which all records are written.
        """
        super().__init__(level=logging.INFO)
        self.writer = writer
        self.tag = tag
        self._step = 0  # incremented after every successfully written record

    def emit(self, record: logging.LogRecord) -> None:
        """Write *record* as text; failures go to ``handleError`` instead of the caller."""
        try:
            self.writer.add_text(self.tag, self.format(record), global_step=self._step)
        except Exception:
            # Standard handler behaviour: a broken sink must not make the
            # application's logging call raise.
            self.handleError(record)
            return
        self._step += 1

+ 12
- 0
tools/train.py View File

@@ -0,0 +1,12 @@
from torch.optim.lr_scheduler import StepLR


class StepMinLR(StepLR):
    """StepLR variant whose ``get_lr`` never reports a value below ``min_lr``."""

    def __init__(self, optimizer, step_size, gamma=0.1, min_lr=0.0, last_epoch=-1):
        # Stored before delegating to StepLR.__init__, which may already
        # invoke get_lr() (presumably during its initial step).
        self.min_lr = min_lr
        super().__init__(optimizer, step_size, gamma=gamma, last_epoch=last_epoch)

    def get_lr(self):
        """Return StepLR's learning rates, each raised to at least ``min_lr``."""
        floor = self.min_lr
        clamped = []
        for stepped in super().get_lr():
            clamped.append(max(stepped, floor))
        return clamped

+ 25
- 0
tools/utils.py View File

@@ -0,0 +1,25 @@
import random
from enum import Enum

import numpy as np
import torch


def set_seed(seed: int = 42) -> None:
    """
    Set the random seed for reproducibility.

    Seeds Python's ``random`` module, NumPy's global (legacy) random state,
    PyTorch's CPU generator and, when CUDA is available, all CUDA generators.

    Args:
        seed (int): The seed value to set for random number generation.
    """
    torch.manual_seed(seed)
    random.seed(seed)
    # np.random.default_rng(seed) would only build a throw-away Generator;
    # np.random.seed seeds the global state used by np.random.* functions.
    # It only accepts values in [0, 2**32), hence the modulo.
    np.random.seed(seed % 2**32)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


class Target(str, Enum):  # counterpart of PyKEEN's LABEL_* constants
    """
    Names one position of a triple. Members are also ``str`` instances and
    compare equal to their plain string values.
    """

    HEAD = "head"
    TAIL = "tail"
    RELATION = "relation"

+ 0
- 0
training/__init__.py View File


+ 0
- 0
training/model_trainers/__init__.py View File


BIN
training/model_trainers/__pycache__/__init__.cpython-310.pyc View File


BIN
training/model_trainers/__pycache__/model_trainer_base.cpython-310.pyc View File


+ 85
- 0
training/model_trainers/model_trainer_base.py View File

@@ -0,0 +1,85 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch
from torch import nn

if TYPE_CHECKING:
from torch.optim.lr_scheduler import LRScheduler


class ModelTrainerBase(ABC):
    """
    Base class for model trainers.

    Holds the model, optimizer, optional LR scheduler and device, and
    implements checkpoint saving/loading. Subclasses provide the loss and
    the train/eval steps.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: LRScheduler | None = None,
        device: torch.device | str = "cpu",
    ) -> None:
        """
        Args:
            model: Model to train; it is moved to ``device``.
            optimizer: Optimizer over the model's parameters.
            scheduler: Optional learning-rate scheduler.
            device: Device the model is moved to (also used again in ``load``).
        """
        self.model = model.to(device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device

    def save(self, save_dpath: str, step: int) -> None:
        """
        Save model, optimizer and (if present) scheduler state to
        ``<save_dpath>/model-<step>.pt``.

        Args:
            save_dpath (str): Directory in which the checkpoint file is written.
            step (int): Training step; stored in the checkpoint and used in the file name.
        """

        data = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "step": step,
        }

        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()

        torch.save(
            data,
            f"{save_dpath}/model-{step}.pt",
        )

    def load(self, load_fpath: str) -> int:
        """
        Load model, optimizer and scheduler state from a checkpoint file.

        Args:
            load_fpath (str): Path of the checkpoint file written by ``save``.

        Returns:
            int: The training step stored in the checkpoint.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (consider weights_only=True where supported).
        data = torch.load(load_fpath, map_location="cpu")
        self.model.load_state_dict(data["model"])
        # Tensors were mapped to CPU above; move the model back to the target device.
        self.model.to(self.device)
        self.optimizer.load_state_dict(data["optimizer"])
        # Checkpoints saved without a scheduler carry no "scheduler" entry;
        # leave the current scheduler untouched instead of loading an empty dict.
        if self.scheduler is not None and "scheduler" in data:
            self.scheduler.load_state_dict(data["scheduler"])

        return data["step"]

    @abstractmethod
    def _loss(self, batch: torch.Tensor) -> torch.Tensor:
        """
        Compute the loss for a batch of data.

        Args:
            batch (torch.Tensor): A tensor containing a batch of data.

        Returns:
            torch.Tensor: The computed loss for the batch.
        """
        ...

    @abstractmethod
    def train_step(self, batch: torch.Tensor) -> float: ...

    @abstractmethod
    def eval_step(self, batch: torch.Tensor) -> float: ...

+ 0
- 0
training/model_trainers/translation/__init__.py View File


BIN
training/model_trainers/translation/__pycache__/__init__.cpython-310.pyc View File


BIN
training/model_trainers/translation/__pycache__/trans_e_trainer.cpython-310.pyc View File


+ 148
- 0
training/model_trainers/translation/trans_e_trainer.py View File

@@ -0,0 +1,148 @@
import torch
import torch.nn.functional as F
from einops import repeat
from pykeen.sampling import BernoulliNegativeSampler
from pykeen.training import SLCWATrainingLoop
from pykeen.triples.instances import SLCWABatch
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

from metrics.ranking import simple_ranking
from models.translation.trans_e import TransE
from tools.params import TrainingParams
from training.model_trainers.model_trainer_base import ModelTrainerBase


class TransETrainer(ModelTrainerBase):
    """
    Trainer for a TransE model.

    Training is delegated to a PyKEEN ``SLCWATrainingLoop`` built around
    ``model.inner_model`` with a Bernoulli negative sampler; evaluation uses
    ``simple_ranking``.
    """

    def __init__(
        self,
        model: TransE,
        training_params: TrainingParams,
        triples_factory: torch.Tensor,
        mapped_triples: torch.Tensor,
        device: torch.device,
        margin: float = 1.0,
        n_negative: int = 1,
        loss_fn: str = "margin",
        **kwargs,
    ) -> None:
        """
        Args:
            model: TransE wrapper exposing ``inner_model``, ``num_entities``
                and ``num_relations``.
            training_params: Supplies ``lr``, ``weight_decay``, ``lr_step`` and ``gamma``.
            triples_factory: Passed to the training loop as its triples factory
                (NOTE(review): annotated as Tensor; presumably a PyKEEN
                TriplesFactory, confirm).
            mapped_triples: Triples handed to the negative sampler.
            device: Device for the model and the negative sampler.
            margin: Margin used by ``_loss``.
            n_negative: Number of negatives per positive.
            loss_fn: "margin" or "adversarial" (see ``_loss``).
            **kwargs: ``adv_temp`` (float, fallback 1.0) sets the temperature
                of the "adversarial" loss; other keys are ignored.
        """
        optimizer = Adam(
            model.inner_model.parameters(),
            lr=training_params.lr,
            weight_decay=training_params.weight_decay,
        )

        scheduler = StepLR(
            optimizer,
            step_size=training_params.lr_step,
            gamma=training_params.gamma,
        )

        super().__init__(model, optimizer, scheduler, device)

        self.triples_factory = triples_factory

        self.margin = margin
        self.n_negative = n_negative
        self.loss_fn = loss_fn
        # Read by the "adversarial" branch of _loss; without this assignment
        # that branch fails with AttributeError.
        self.adv_temp = kwargs.get("adv_temp", 1.0)

        self.negative_sampler = BernoulliNegativeSampler(
            mapped_triples=mapped_triples.to(device),
            num_entities=model.num_entities,
            num_relations=model.num_relations,
            num_negs_per_pos=n_negative,
        ).to(device)

        self.training_loop = SLCWATrainingLoop(
            model=model.inner_model,
            triples_factory=triples_factory,
            optimizer=optimizer,
            negative_sampler=self.negative_sampler,
            lr_scheduler=scheduler,
        )

    def create_data_loader(
        self,
        triples_factory: torch.Tensor,
        batch_size: int,
        shuffle: bool = True,
    ) -> DataLoader:
        """
        Build a training data loader through the PyKEEN training loop.

        Args:
            triples_factory: Factory to load from; ``None`` selects the
                trainer's own factory.
            batch_size: Batch size of the loader.
            shuffle: Whether the loader shuffles.
        """
        triples_factory = self.triples_factory if triples_factory is None else triples_factory
        return self.training_loop._create_training_data_loader(
            triples_factory=triples_factory,
            sampler=None,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=False,
        )

    def _loss(self, pos_scores: torch.Tensor, neg_scores: torch.Tensor) -> torch.Tensor:
        """
        Compute the loss based on the positive and negative scores.

        Args:
            pos_scores (torch.Tensor): Scores for positive triples, one per
                positive (expanded along a second axis to match ``neg_scores``).
            neg_scores (torch.Tensor): Scores for negative triples; its second
                dimension is the number of negatives per positive.

        Returns:
            torch.Tensor: The computed loss.

        Raises:
            ValueError: If ``self.loss_fn`` is neither "margin" nor "adversarial".
        """

        k = neg_scores.shape[1]

        if self.loss_fn == "margin":
            target = -torch.ones_like(neg_scores)  # want s_pos < s_neg
            return F.margin_ranking_loss(
                repeat(pos_scores, "b -> b k", k=k),
                neg_scores,
                target,
                margin=self.margin,
            )

        if self.loss_fn == "adversarial":
            # importance weights (detach so gradients flow only through log-sigmoid terms)
            w = F.softmax(self.adv_temp * neg_scores, dim=1).detach()

            pos_term = -F.logsigmoid(self.margin - pos_scores)  # (B,)
            neg_term = -(w * F.logsigmoid(neg_scores - self.margin)).sum(dim=1)  # (B,)

            return (pos_term + neg_term).mean()

        # An unknown name used to fall through and return None silently.
        msg = f"Unknown loss_fn '{self.loss_fn}'; expected 'margin' or 'adversarial'"
        raise ValueError(msg)

    def train_step(self, batch: SLCWABatch, step: int) -> float:
        """
        Run the PyKEEN training loop up to epoch ``step + 1``.

        Args:
            batch: Only used to derive the batch size via the training loop.
            step: Zero-based call index; training is resumed
                (``continue_training``) for every step after the first.

        Returns:
            float: The last entry of the value returned by ``training_loop.train``.
        """

        continue_training = step > 0

        loss = self.training_loop.train(
            triples_factory=self.triples_factory,
            num_epochs=step + 1,
            batch_size=self.training_loop._get_batch_size(batch),
            continue_training=continue_training,
            use_tqdm=False,
            use_tqdm_batch=False,
            label_smoothing=0.0,  # Assuming no label smoothing for simplicity
        )[-1]

        return loss

    def eval_step(self, batch: torch.Tensor) -> torch.Tensor:
        """Return rank tensor for this batch (needed for MRR / Hits@K)."""
        self.model.eval()
        with torch.no_grad():
            return simple_ranking(self.model, batch.to(self.device))

+ 121
- 0
training/model_trainers/translation/trans_h_trainer.py View File

@@ -0,0 +1,121 @@
import torch
from pykeen.losses import MarginRankingLoss
from pykeen.sampling import BernoulliNegativeSampler
from pykeen.training import SLCWATrainingLoop
from pykeen.triples.instances import SLCWABatch
from torch.optim import Adam

from metrics.ranking import simple_ranking
from models.translation.trans_h import TransH
from tools.params import TrainingParams
from tools.train import StepMinLR
from training.model_trainers.model_trainer_base import ModelTrainerBase


class TransHTrainer(ModelTrainerBase):
    """
    Trainer for a TransH model: wires an Adam optimizer, a StepMinLR
    schedule and a Bernoulli negative sampler into a PyKEEN
    ``SLCWATrainingLoop`` built around ``model.inner_model``.
    """

    def __init__(
        self,
        model: TransH,
        training_params: TrainingParams,
        triples_factory: torch.Tensor,
        mapped_triples: torch.Tensor,
        device: torch.device,
        margin: float = 1.0,
        n_negative: int = 1,
        regul_rate: float = 0.0,
        loss_fn: str = "margin",
        **kwargs,
    ) -> None:
        adam = Adam(
            model.inner_model.parameters(),
            lr=training_params.lr,
            weight_decay=training_params.weight_decay,
        )
        lr_schedule = StepMinLR(
            adam,
            step_size=training_params.lr_step,
            gamma=training_params.gamma,
            min_lr=training_params.min_lr,
        )

        super().__init__(model, adam, lr_schedule, device)

        # Plain configuration values kept for later use.
        self.triples_factory = triples_factory
        self.margin = margin
        self.n_negative = n_negative
        self.regul_rate = regul_rate
        self.loss_fn = loss_fn

        sampler = BernoulliNegativeSampler(
            mapped_triples=mapped_triples.to(device),
            num_entities=model.num_entities,
            num_relations=model.num_relations,
            num_negs_per_pos=n_negative,
        )
        self.negative_sampler = sampler.to(device)

        self.training_loop = SLCWATrainingLoop(
            model=model.inner_model,
            triples_factory=triples_factory,
            optimizer=adam,
            negative_sampler=self.negative_sampler,
            lr_scheduler=lr_schedule,
        )

    def _loss(self) -> torch.Tensor:
        """Return the ``loss`` attribute of the wrapped model."""
        return self.model.loss

    def create_data_loader(
        self,
        triples_factory: torch.Tensor,
        batch_size: int,
        shuffle: bool = True,
    ) -> torch.utils.data.DataLoader:
        """Create a training data loader; ``None`` selects the trainer's own triples factory."""
        if triples_factory is None:
            triples_factory = self.triples_factory
        loader_options = {
            "triples_factory": triples_factory,
            "sampler": None,
            "batch_size": batch_size,
            "shuffle": shuffle,
            "drop_last": False,
        }
        return self.training_loop._create_training_data_loader(**loader_options)

    @property
    def loss(self) -> torch.Tensor:
        # NOTE(review): annotated as Tensor, but a fresh MarginRankingLoss
        # object is built and returned on every access.
        return MarginRankingLoss(margin=self.margin, reduction="mean")

    def train_step(
        self,
        batch: SLCWABatch,
        step: int,
    ) -> float:
        """
        Run the PyKEEN training loop up to epoch ``step + 1``.

        Args:
            batch: Only used to derive the batch size via the training loop.
            step: Zero-based call index; every step after the first resumes training.

        Returns:
            float: The last entry of the value returned by ``training_loop.train``.
        """
        history = self.training_loop.train(
            triples_factory=self.triples_factory,
            num_epochs=step + 1,
            batch_size=self.training_loop._get_batch_size(batch),
            continue_training=step > 0,
            use_tqdm=False,
            use_tqdm_batch=False,
            label_smoothing=0.0,  # no label smoothing
        )
        return history[-1]

    def eval_step(self, batch: torch.Tensor) -> torch.Tensor:
        """Return rank tensor for this batch (needed for MRR / Hits@K)."""
        self.model.eval()
        with torch.no_grad():
            ranks = simple_ranking(self.model, batch.to(self.device))
        return ranks

+ 116
- 0
training/model_trainers/translation/trans_r_trainer.py View File

@@ -0,0 +1,116 @@
import torch
from pykeen.sampling import BernoulliNegativeSampler
from pykeen.training import SLCWATrainingLoop
from pykeen.triples.instances import SLCWABatch
from torch.optim import Adam

from metrics.ranking import simple_ranking
from models.translation.trans_r import TransR
from tools.params import TrainingParams
from tools.train import StepMinLR
from training.model_trainers.model_trainer_base import ModelTrainerBase


class TransRTrainer(ModelTrainerBase):
    """Model trainer that drives a ``TransR`` model via a pykeen sLCWA loop.

    It wires together an Adam optimizer, a ``StepMinLR`` scheduler, a
    Bernoulli negative sampler and an ``SLCWATrainingLoop`` operating on the
    model's ``inner_model``.
    """

    def __init__(
        self,
        model: TransR,
        training_params: TrainingParams,
        triples_factory: torch.Tensor,
        mapped_triples: torch.Tensor,
        device: torch.device,
        margin: float = 1.0,
        n_negative: int = 1,
        regul_rate: float = 0.0,
        loss_fn: str = "margin",
        **kwargs,
    ) -> None:
        """Create optimizer, scheduler, negative sampler and training loop.

        Args:
            model: Wrapper exposing ``inner_model``, ``num_entities`` and
                ``num_relations``.
            training_params: Supplies ``lr``, ``weight_decay``, ``lr_step``,
                ``gamma`` and ``min_lr``.
            triples_factory: Training triples source handed to the loop.
                NOTE(review): annotated as a tensor, but it is passed as
                the loop's ``triples_factory`` -- confirm the real type.
            mapped_triples: Triples given to the negative sampler (moved to
                ``device`` first).
            device: Device for the sampler and the base trainer.
            margin: Stored on the instance as ``self.margin``.
            n_negative: Number of negatives per positive for the sampler.
            regul_rate: Stored on the instance as ``self.regul_rate``.
            loss_fn: Stored on the instance as ``self.loss_fn``.
            **kwargs: Accepted but not used by this constructor.
        """
        optimizer = Adam(
            model.inner_model.parameters(),
            lr=training_params.lr,
            weight_decay=training_params.weight_decay,
        )
        scheduler = StepMinLR(
            optimizer,
            step_size=training_params.lr_step,
            gamma=training_params.gamma,
            min_lr=training_params.min_lr,
        )

        super().__init__(model, optimizer, scheduler, device)

        # Plain configuration kept on the instance.
        self.triples_factory = triples_factory
        self.margin, self.n_negative = margin, n_negative
        self.regul_rate, self.loss_fn = regul_rate, loss_fn

        sampler = BernoulliNegativeSampler(
            mapped_triples=mapped_triples.to(device),
            num_entities=model.num_entities,
            num_relations=model.num_relations,
            num_negs_per_pos=n_negative,
        )
        self.negative_sampler = sampler.to(device)

        # The loop trains the inner model with the same optimizer/scheduler
        # that were registered with the base class above.
        self.training_loop = SLCWATrainingLoop(
            model=model.inner_model,
            triples_factory=triples_factory,
            optimizer=optimizer,
            negative_sampler=self.negative_sampler,
            lr_scheduler=scheduler,
        )

    def create_data_loader(
        self,
        triples_factory: torch.Tensor,
        batch_size: int,
        shuffle: bool = True,
    ) -> torch.utils.data.DataLoader:
        """Build a training data loader through the training loop.

        Args:
            triples_factory: Source of triples; ``None`` selects the factory
                stored on this trainer.
            batch_size: Batch size forwarded to the loader.
            shuffle: Whether the loader shuffles its data.

        Returns:
            The result of ``training_loop._create_training_data_loader``,
            called with no explicit sampler and ``drop_last=False``.
        """
        if triples_factory is None:
            triples_factory = self.triples_factory

        make_loader = self.training_loop._create_training_data_loader
        return make_loader(
            triples_factory=triples_factory,
            sampler=None,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=False,
        )

    def _loss(self) -> torch.Tensor:
        """Return the ``loss`` attribute exposed by the wrapped model."""
        wrapped_model = self.model
        return wrapped_model.loss

    def train_step(
        self,
        batch: SLCWABatch,
        step: int,
    ) -> float:
        """
        Run one training call on the underlying training loop.

        Args:
            batch (SLCWABatch): A batch from the training loader. It is only
                used here to derive the batch size handed to the loop.
            step (int): Zero-based index of the current step. Training is
                continued (rather than restarted) whenever ``step > 0``, and
                ``step + 1`` is passed as ``num_epochs``.

        Returns:
            float: The last element of the value returned by
            ``training_loop.train``.
        """
        loop = self.training_loop

        # NOTE(review): ``num_epochs`` presumably acts as a running total
        # when continuing, so each call would add exactly one epoch --
        # confirm against the pykeen version in use.
        epoch_losses = loop.train(
            triples_factory=self.triples_factory,
            num_epochs=step + 1,
            batch_size=loop._get_batch_size(batch),
            continue_training=step > 0,
            use_tqdm=False,
            use_tqdm_batch=False,
            label_smoothing=0.0,  # label smoothing is deliberately switched off
        )
        return epoch_losses[-1]

    def eval_step(self, batch: torch.Tensor) -> torch.Tensor:
        """Compute ranks for one batch (used for MRR / Hits@K).

        The model is switched to eval mode, the batch is moved to this
        trainer's device, and ``simple_ranking`` is run with gradient
        tracking disabled.
        """
        self.model.eval()
        with torch.no_grad():
            batch_on_device = batch.to(self.device)
            ranks = simple_ranking(self.model, batch_on_device)
        return ranks

+ 214
- 0
training/trainer.py View File

@@ -0,0 +1,214 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any

import torch
from pykeen.evaluation import RankBasedEvaluator
from pykeen.metrics.ranking import HitsAtK, InverseHarmonicMeanRank
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from tools import (
CheckpointManager,
DeterministicRandomSampler,
TensorBoardHandler,
get_pretty_logger,
set_seed,
)

if TYPE_CHECKING:
from data.kg_dataset import KGDataset
from tools import CommonParams, TrainingParams

from .model_trainers.model_trainer_base import ModelTrainerBase


logger = get_pretty_logger(__name__)


cpu_device = torch.device("cpu")


class Trainer:
    """Step-driven orchestrator for training, validation and checkpointing.

    Optimisation itself is delegated to a ``ModelTrainerBase`` instance. This
    class owns the data loaders, the global step counter, periodic logging
    (console and TensorBoard), evaluation and checkpoint saving/loading.
    """

    def __init__(
        self,
        train_dataset: KGDataset,
        val_dataset: KGDataset | None,
        model_trainer: ModelTrainerBase,
        common_params: CommonParams,
        training_params: TrainingParams,
        device: torch.device = cpu_device,
    ) -> None:
        """Set up data loaders, checkpointing and logging.

        Args:
            train_dataset: Dataset whose ``triples_factory`` feeds the
                training loader.
            val_dataset: Dataset used for validation; ``train_dataset`` is
                used when this is ``None``.
            model_trainer: Object that builds the training loader and
                implements ``train_step``/``save``/``load``/``reset``.
            common_params: Run-level settings (logging cadence, save and
                load paths, run name, log directory, evaluate-only flag).
            training_params: Training settings (seed, batch sizes, worker
                count, validation sample size, step schedule).
            device: Device stored on the trainer; defaults to the CPU.

        Raises:
            FileNotFoundError: If ``common_params.load_path`` is set but the
                resolved checkpoint file does not exist.
        """
        set_seed(training_params.seed)

        self.train_dataset = train_dataset
        # Without a dedicated validation split, validate on the training data.
        self.val_dataset = val_dataset if val_dataset is not None else train_dataset
        self.model_trainer = model_trainer

        self.common_params = common_params
        self.training_params = training_params

        self.device = device

        self.train_loader = self.model_trainer.create_data_loader(
            triples_factory=self.train_dataset.triples_factory,
            batch_size=training_params.batch_size,
            shuffle=True,
        )

        self.train_iterator = self._data_iterator(self.train_loader)

        # Fixed seed for the validation sampler, independent of the
        # training seed set above.
        seed = 1234
        val_sampler = DeterministicRandomSampler(
            len(self.val_dataset),
            sample_size=self.training_params.validation_sample_size,
            seed=seed,
        )
        self.valid_loader = DataLoader(
            self.val_dataset,
            batch_size=training_params.validation_batch_size,
            shuffle=False,
            num_workers=training_params.num_workers,
            sampler=val_sampler,
        )
        self.valid_iterator = self._data_iterator(self.valid_loader)

        self.step = 0
        self.log_every = common_params.log_every
        self.log_console_every = common_params.log_console_every

        self.save_dpath = common_params.save_dpath
        self.save_every = common_params.save_every

        self.checkpoint_manager = CheckpointManager(
            root_directory=self.save_dpath,
            run_name=common_params.run_name,
            load_only=common_params.evaluate_only,
        )
        self.ckpt_dpath = self.checkpoint_manager.checkpoint_directory
        if not common_params.evaluate_only:
            logger.info(f"Saving checkpoints to {self.ckpt_dpath}")

        if self.common_params.load_path:
            self.load()

        # TensorBoard output is only produced for training runs.
        self.writer = None
        if not common_params.evaluate_only:
            run_name = self.checkpoint_manager.run_name
            self.writer = SummaryWriter(log_dir=f"{common_params.log_dir}/{run_name}")
            logger.addHandler(TensorBoardHandler(self.writer))

    def log(self, tag: str, log_value: Any, console_msg: str | None = None) -> None:
        """
        Record a scalar in TensorBoard and optionally print a console message.

        Args:
            tag (str): TensorBoard tag under which the scalar is stored.
            log_value (Any): Scalar value recorded at the current step.
            console_msg (str | None): Message sent to the logger; nothing is
                printed when ``None``.
        """
        if console_msg is not None:
            logger.info(console_msg)
        if self.writer is not None:
            self.writer.add_scalar(tag, log_value, self.step)

    def save(self) -> None:
        """Save the model trainer state for the current step into the checkpoint directory."""
        save_dpath = Path(self.ckpt_dpath)
        save_dpath.mkdir(parents=True, exist_ok=True)
        self.model_trainer.save(save_dpath, self.step)

    def load(self) -> None:
        """Restore the model trainer state and the step counter.

        The file is resolved through the checkpoint manager from
        ``common_params.load_path``.

        Raises:
            FileNotFoundError: If the resolved checkpoint file does not exist.
        """
        load_path = self.checkpoint_manager.get_model_fpath(self.common_params.load_path)
        load_fpath = Path(load_path)

        if not load_fpath.exists():
            msg = f"Load path {load_fpath} does not exist."
            raise FileNotFoundError(msg)

        self.step = self.model_trainer.load(load_fpath)
        logger.info(f"Model loaded from {load_fpath}.")

    def _data_iterator(self, loader: DataLoader) -> iter:
        """Yield batches from ``loader`` forever, restarting it when exhausted.

        Raises:
            ValueError: If a complete pass over ``loader`` yields no batch.
                Without this check an empty loader would make the generator
                spin forever without ever producing a value.
        """
        while True:
            produced_any = False
            for batch in loader:
                produced_any = True
                yield batch
            if not produced_any:
                msg = "Cannot iterate over an empty data loader."
                raise ValueError(msg)

    def train(self) -> None:
        """Run a single training step, log its loss and advance the step counter."""
        data_batch = next(self.train_iterator)

        loss = self.model_trainer.train_step(data_batch, self.step)

        if self.step % self.log_console_every == 0:
            logger.info(f"[train] step {self.step:6d} | loss = {loss:.4f}")
        if self.step % self.log_every == 0:
            self.log("train/loss", loss)

        self.step += 1

    def evaluate(self) -> dict[str, float]:
        """Run filtered rank-based evaluation on the validation triples.

        All triples of ``val_dataset`` are evaluated (the sampled
        ``valid_loader`` is not used here), with the training and validation
        triples supplied as additional filter triples. Every metric is
        logged to the console and, when a writer exists, to TensorBoard.

        Returns:
            Mapping with the keys ``"mrr"``, ``"hits_at_1"``, ``"hits_at_3"``
            and ``"hits_at_10"``.
        """
        ks = [1, 3, 10]

        rank_metrics = [InverseHarmonicMeanRank()] + [HitsAtK(k) for k in ks]

        evaluator = RankBasedEvaluator(
            metrics=rank_metrics,
            filtered=True,
            batch_size=self.training_params.validation_batch_size,
        )

        result = evaluator.evaluate(
            model=self.model_trainer.model,
            mapped_triples=self.val_dataset.triples,
            additional_filter_triples=[
                self.train_dataset.triples,
                self.val_dataset.triples,
            ],
            batch_size=self.training_params.validation_batch_size,
        )

        mrr = result.get_metric("both.realistic.inverse_harmonic_mean_rank")
        hits_at = {f"hits_at_{k}": result.get_metric(f"both.realistic.hits_at_{k}") for k in ks}

        scores = {
            "mrr": mrr,
            **hits_at,
        }

        for name, value in scores.items():
            self.log(f"valid/{name}", value, console_msg=f"[valid] {name:6s} = {value:.4f}")

        # Hand the metrics back to the caller, as the return annotation
        # promises; previously they were computed and then discarded.
        return scores

    # ───────────────────────── #
    # master schedule
    # ───────────────────────── #
    def run(self) -> None:
        """
        Alternate training, evaluation and checkpointing until
        ``training_params.num_train_steps`` steps have been executed.

        Evaluation runs whenever the step counter (after the increment done
        by ``train``) is a multiple of ``eval_every``; checkpoints are
        written on multiples of ``save_every``.
        """

        total_steps = self.training_params.num_train_steps
        eval_every = self.training_params.eval_every

        while self.step < total_steps:
            self.train()
            if self.step % eval_every == 0:
                logger.info(f"Evaluating at step {self.step}...")
                self.evaluate()

            if self.step % self.save_every == 0:
                logger.info(f"Saving model at step {self.step}...")
                self.save()

        logger.info("Training complete.")

    def reset(self) -> None:
        """
        Reset the model trainer, recreate both data iterators and set the
        step counter back to zero.
        """
        self.model_trainer.reset()
        self.train_iterator = self._data_iterator(self.train_loader)
        self.valid_iterator = self._data_iterator(self.valid_loader)
        self.step = 0
        logger.info("Trainer state has been reset.")

Loading…
Cancel
Save