
Added code base.

main · Naser Kazemi · 3 weeks ago
commit 02a34881d1
71 changed files with 4858 additions and 0 deletions
  1. README.md (+1 -0)
  2. build_crec_datasets.py (+106 -0)
  3. configs/common/common.yaml (+20 -0)
  4. configs/config.yaml (+9 -0)
  5. configs/data/data.yaml (+3 -0)
  6. configs/data/fb15k.yaml (+8 -0)
  7. configs/data/wn18.yaml (+8 -0)
  8. configs/data/wn18rr.yaml (+8 -0)
  9. configs/data/yago3_10.yaml (+9 -0)
  10. configs/model/model.yaml (+3 -0)
  11. configs/model/trans_e.yaml (+9 -0)
  12. configs/model/trans_h.yaml (+11 -0)
  13. configs/model/trans_r.yaml (+12 -0)
  14. configs/training/training.yaml (+21 -0)
  15. configs/training/trans_e_trainer.yaml (+11 -0)
  16. configs/training/trans_r_trainer.yaml (+11 -0)
  17. configs/trans_e_fb15k.yaml (+17 -0)
  18. configs/trans_e_wn18.yaml (+17 -0)
  19. configs/trans_e_wn18rr.yaml (+18 -0)
  20. configs/trans_e_yago3_10.yaml (+17 -0)
  21. configs/trans_r.yaml (+17 -0)
  22. data/__init__.py (+6 -0)
  23. data/fb15k.py (+49 -0)
  24. data/hationet.py (+57 -0)
  25. data/kg_dataset.py (+89 -0)
  26. data/open_bio_link.py (+57 -0)
  27. data/openke_wiki.py (+57 -0)
  28. data/wn18.py (+112 -0)
  29. data/yago3_10.py (+70 -0)
  30. eval_datasets.py (+38 -0)
  31. main.py (+80 -0)
  32. metrics/__init__.py (+0 -0)
  33. metrics/base_metric.py (+55 -0)
  34. metrics/c_swklf.py (+264 -0)
  35. metrics/crec_modifier.py (+732 -0)
  36. metrics/crec_radius_sample.py (+106 -0)
  37. metrics/greedy_crec.py (+226 -0)
  38. metrics/ranking.py (+30 -0)
  39. metrics/wlcrec.py (+237 -0)
  40. metrics/wlec.py (+135 -0)
  41. models/__init__.py (+0 -0)
  42. models/base_model.py (+127 -0)
  43. models/translation/__init__.py (+0 -0)
  44. models/translation/__pycache__/__init__.cpython-310.pyc (BIN)
  45. models/translation/__pycache__/trans_e.cpython-310.pyc (BIN)
  46. models/translation/trans_e.py (+70 -0)
  47. models/translation/trans_h.py (+76 -0)
  48. models/translation/trans_r.py (+79 -0)
  49. pyproject.toml (+82 -0)
  50. setup/install.sh (+28 -0)
  51. setup/kg_env.yaml (+163 -0)
  52. tools/__init__.py (+6 -0)
  53. tools/checkpoint_manager.py (+221 -0)
  54. tools/params.py (+45 -0)
  55. tools/pretty_logger.py (+417 -0)
  56. tools/sampling.py (+66 -0)
  57. tools/tb_handler.py (+21 -0)
  58. tools/utils.py (+25 -0)
  59. tools/utils.py — wait

README.md (+1 -0)

# KGEvaluation

build_crec_datasets.py (+106 -0)

from __future__ import annotations

import argparse

import torch

from data import YAGO310Dataset

# from metrics.crec_modifier import tune_crec_parallel_edits
from data.kg_dataset import KGDataset
from metrics.crec_radius_sample import find_sample
from metrics.wlcrec import WLCREC
from tools import get_pretty_logger

logger = get_pretty_logger(__name__)


# --------------------------------------------------------------------
# CLI
# --------------------------------------------------------------------
def main():
    p = argparse.ArgumentParser()
    p.add_argument("--depth", type=int, default=5)
    p.add_argument("--lower", type=float, default=0.25)
    p.add_argument("--upper", type=float, default=0.26)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--save", type=str, default="fb15k_crec_bi.pt")
    args = p.parse_args()

    # ds = FB15KDataset()
    ds = YAGO310Dataset()
    triples = ds.triples
    n_ent, n_rel = ds.num_entities, ds.num_relations
    logger.info("YAGO3-10 loaded |V|=%d |R|=%d |T|=%d", n_ent, n_rel, len(triples))

    # Earlier tuning strategies, kept for reference:
    # wlcrec = WLCREC(ds)
    # c = wlcrec.compute(5)[-2]  # C_NWLEC
    # colours = wlcrec.wl_colours(args.depth)[-1]
    #
    # tuned = tune_crec(
    #     wlcrec, n_ent, n_rel, depth=args.depth, lo=args.lower, hi=args.upper, seed=args.seed
    # )
    #
    # tuned = tune_crec_parallel_edits(
    #     triples,
    #     n_ent,
    #     n_rel,
    #     depth=args.depth,
    #     lo=args.lower,
    #     hi=args.upper,
    #     seed=args.seed,
    # )
    #
    # tuned = fast_tune_crec(
    #     triples.tolist(),
    #     colours,
    #     n_ent,
    #     n_rel,
    #     depth=args.depth,
    #     c=c,
    #     lo=args.lower,
    #     hi=args.upper,
    #     max_iters=1000,
    #     max_workers=10,  # adjust number of workers as needed
    # )

    tuned = find_sample(
        triples.tolist(),
        n_ent,
        n_rel,
        depth=args.depth,
        ratio=0.1,  # adjust ratio as needed
        radius=2,  # adjust radius as needed
        lower=args.lower,
        upper=args.upper,
        max_tries=1000,
        seed=58,
    )

    tuned_ds = KGDataset(
        triples_factory=None,
        triples=torch.tensor(tuned, dtype=torch.long),
        num_entities=n_ent,
        num_relations=n_rel,
        split="train",
    )

    tuned_wlcrec = WLCREC(tuned_ds)
    entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = tuned_wlcrec.compute(5)
    print(
        "\nTuned CREC results (H=5)",
        f"\n  • avg WLEC        : {entropy:.6f} nats",
        f"\n  • C_ratio         : {c_ratio:.6f}",
        f"\n  • C_NWLEC         : {c_nwlec:.6f}",
        f"\n  • H_cond(R|S_H)   : {h_cond:.6f} nats",
        f"\n  • D_ratio         : {d_ratio:.6f}",
        f"\n  • D_NWLEC         : {d_nwlec:.6f}",
        sep="",
    )

    torch.save(torch.tensor(tuned, dtype=torch.long), args.save)
    logger.info("Saved tuned triples → %s", args.save)


if __name__ == "__main__":
    main()
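For a quick sanity check, the saved tensor can be reloaded and re-wrapped in a KGDataset; a minimal sketch, assuming only the modules above (the entity/relation counts are recovered as upper bounds from the tensor itself, since the true counts are not stored in the file):

import torch

from data.kg_dataset import KGDataset
from metrics.wlcrec import WLCREC

# Path is whatever was passed via --save above.
triples = torch.load("fb15k_crec_bi.pt", map_location="cpu").to(torch.long)
ds = KGDataset(
    triples_factory=None,
    triples=triples,
    num_entities=int(triples[:, [0, 2]].max()) + 1,  # upper bound from the data
    num_relations=int(triples[:, 1].max()) + 1,
    split="train",
)
print(ds)
print(WLCREC(ds).compute(5))  # same six-tuple as printed above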

configs/common/common.yaml (+20 -0)

_target_: tools.params.CommonParams

model_name: "default_model"
project_name: "my_project"
run_name: ...

# Where to write all outputs (checkpoints / tensorboard logs / etc.)
save_dpath: "./checkpoints"
save_every: 200

# If you want to resume from a checkpoint, set load_path to a string path;
# otherwise leave it null or empty.
load_path: null

# Default log directory
log_dir: "logs"
log_every: 10
log_console_every: 10

evaluate_only: false

configs/config.yaml (+9 -0)

defaults:
  - common: common
  - model: model
  - data: data
  - training: training

hydra:
  run:
    dir: "./outputs"

configs/data/data.yaml (+3 -0)

_target_: data.wn18.WN18RRDataset

split: train

configs/data/fb15k.yaml (+8 -0)

train:
  _target_: data.fb15k.FB15KDataset
  split: train

valid:
  _target_: data.fb15k.FB15KDataset
  split: valid

configs/data/wn18.yaml (+8 -0)

train:
  _target_: data.wn18.WN18Dataset
  split: train

valid:
  _target_: data.wn18.WN18Dataset
  split: valid

configs/data/wn18rr.yaml (+8 -0)

train:
  _target_: data.wn18.WN18RRDataset
  split: train

valid:
  _target_: data.wn18.WN18RRDataset
  split: valid

configs/data/yago3_10.yaml (+9 -0)

train:
  _target_: data.yago3_10.YAGO310Dataset
  split: train

valid:
  _target_: data.yago3_10.YAGO310Dataset
  split: valid


configs/model/model.yaml (+3 -0)

num_entities: 100
num_relations: 100
dim: 100

configs/model/trans_e.yaml (+9 -0)

defaults:
  - model
  - _self_

_target_: models.translation.trans_e.TransE

dim: 200
sf_norm: 1
p_norm: true

configs/model/trans_h.yaml (+11 -0)

defaults:
  - model
  - _self_

_target_: models.translation.trans_h.TransH

dim: 200
sf_norm: 1
p_norm: true
p_norm_value: 2
margin: 1.0

configs/model/trans_r.yaml (+12 -0)

defaults:
  - model
  - _self_

_target_: models.translation.trans_r.TransR

entity_dim: 200
relation_dim: 30
sf_norm: 1
p_norm: true
p_norm_value: 2
margin: 1.0

configs/training/training.yaml (+21 -0)

_target_: tools.params.TrainingParams

lr: 0.005
min_lr: 0.0001
weight_decay: 0.0
t0: 50
lr_step: 10
gamma: 0.95
batch_size: 8192
num_workers: 3
seed: 42

num_train_steps: 10000
eval_every: 100

validation_sample_size: 1000
validation_batch_size: 64

model_trainer:
  _target_: training.model_trainers.model_trainer_base.ModelTrainerBase

configs/training/trans_e_trainer.yaml (+11 -0)

defaults:
  - training
  - _self_

model_trainer:
  _target_: training.model_trainers.translation.trans_e_trainer.TransETrainer
  margin: 1.0
  n_negative: 5
  regul_rate: 0.0
  loss_fn: "margin"

configs/training/trans_r_trainer.yaml (+11 -0)

defaults:
  - training
  - _self_

model_trainer:
  _target_: training.model_trainers.translation.trans_r_trainer.TransRTrainer
  margin: 1.0
  n_negative: 10
  regul_rate: 0.0
  loss_fn: "margin"

configs/trans_e_fb15k.yaml (+17 -0)

defaults:
  - config
  - override common: common
  - override data: fb15k
  - override model: trans_e
  - override training: trans_e_trainer
  - _self_

training:
  model_trainer:
    loss_fn: "margin"

common:
  run_name: "TransE_FB15K"
  evaluate_only: true
  load_path: (2, 1800)

configs/trans_e_wn18.yaml (+17 -0)

defaults:
  - config
  - override common: common
  - override data: wn18
  - override model: trans_e
  - override training: trans_e_trainer
  - _self_

training:
  model_trainer:
    loss_fn: "margin"

common:
  run_name: "TransE_WN18"
  evaluate_only: true
  load_path: (82, 3000)

configs/trans_e_wn18rr.yaml (+18 -0)

defaults:
  - config
  - override common: common
  - override data: wn18rr
  - override model: trans_e
  - override training: trans_e_trainer
  - _self_

training:
  model_trainer:
    loss_fn: "margin"

common:
  run_name: "TransE_WN18rr"
  evaluate_only: true
  load_path: (2, 6000)

configs/trans_e_yago3_10.yaml (+17 -0)

defaults:
  - config
  - override common: common
  - override data: yago3_10
  - override model: trans_e
  - override training: trans_e_trainer
  - _self_

training:
  model_trainer:
    loss_fn: "margin"

common:
  run_name: "TransE_YAGO310"
  evaluate_only: true
  load_path: (1, 2600)

configs/trans_r.yaml (+17 -0)

defaults:
  - config
  - override common: common
  - override data: wn18
  - override model: trans_r
  - override training: trans_r_trainer
  - _self_

training:
  model_trainer:
    loss_fn: "margin"

common:
  run_name: "TransR_WN18"
  # evaluate_only: true
  # load_path: (1, 1000)
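These experiment files compose through Hydra's defaults lists (base config, then the data/model/training groups, then _self_ overrides). A minimal sketch of inspecting the merged result with Hydra's compose API, assuming hydra-core ≥ 1.2 (drop version_base on older versions):

from hydra import compose, initialize
from omegaconf import OmegaConf

# Resolve configs/trans_e_wn18.yaml the same way `main.py` would.
with initialize(config_path="configs", version_base=None):
    cfg = compose(config_name="trans_e_wn18")
print(OmegaConf.to_yaml(cfg))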

data/__init__.py (+6 -0)

from .fb15k import FB15KDataset
from .wn18 import WN18Dataset, WN18RRDataset
from .yago3_10 import YAGO310Dataset
from .hationet import HetionetDataset
from .open_bio_link import OpenBioLinkDataset
from .openke_wiki import OpenKEWikiDataset

data/fb15k.py (+49 -0)

from typing import List, Tuple

import torch
from pykeen.datasets import FB15k

from .kg_dataset import KGDataset


class FB15KDataset(KGDataset):
    """A wrapper around PyKEEN's FB15k dataset producing KGDataset instances."""

    def __init__(self, split: str = "train") -> None:
        dataset = FB15k()

        # if split == "train":
        if True:  # NOTE: always start from the training factory; its triples are replaced below
            tf = dataset.training
        elif split in ("valid", "val", "validation"):
            tf = dataset.validation
        elif split in ("test", "testing"):
            tf = dataset.testing
        else:
            msg = f"Unrecognized split '{split}'"
            raise ValueError(msg)

        # Swap in the pre-built CREC-tuned triples (hard-coded local path).
        custom_triples = torch.load(
            "/home/n_kazemi/projects/KGEvaluation/CustomDataset/fb15k/fb15k_crec_bi_1.pt",
            map_location=torch.device("cpu"),
        ).to(torch.long)
        # .tolist()

        tf = tf.clone_and_exchange_triples(custom_triples)

        triples_list: List[Tuple[int, int, int]] = [
            (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples
        ]

        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=dataset.num_entities,
            num_relations=dataset.num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.tensor(super().__getitem__(idx), dtype=torch.long)

data/hationet.py (+57 -0)

from typing import List, Tuple

import torch
from pykeen.datasets import Hetionet

from .kg_dataset import KGDataset  # Importing KGDataset from the same package


class HetionetDataset(KGDataset):
    """A wrapper around PyKEEN's Hetionet dataset that produces KGDataset instances."""

    def __init__(self, split: str = "train") -> None:
        """
        Initialize the Hetionet wrapper.

        :param split: One of "train", "valid", or "test" indicating which split to load.
        """
        # Load the PyKEEN Hetionet dataset
        dataset = Hetionet()

        # Select the appropriate TriplesFactory based on the requested split
        if split == "train":
            tf = dataset.training
        elif split in ("valid", "val", "validation"):
            tf = dataset.validation
        elif split in ("test", "testing"):
            tf = dataset.testing
        else:
            msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
            raise ValueError(msg)

        # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3);
        # convert each row into a Python tuple of ints
        triples_list: List[Tuple[int, int, int]] = [
            (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples
        ]

        # Get the total number of entities and relations in the entire dataset
        num_entities = dataset.num_entities
        num_relations = dataset.num_relations

        # Initialize the base KGDataset with the selected triples
        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=num_entities,
            num_relations=num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        data = super().__getitem__(idx)

        # Convert the triple to a torch tensor
        return torch.tensor(data, dtype=torch.long)

data/kg_dataset.py (+89 -0)

"""PyTorch Dataset wrapping indexed triples (h, r, t)."""

from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from pykeen.triples import TriplesFactory
from torch.utils.data import Dataset


class KGDataset(Dataset):
    """PyTorch Dataset wrapping indexed triples (h, r, t)."""

    def __init__(
        self,
        triples_factory: Optional[TriplesFactory],
        triples: List[Tuple[int, int, int]],
        num_entities: int,
        num_relations: int,
        split: str = "train",
    ) -> None:
        super().__init__()
        self.triples_factory = triples_factory
        self.triples = torch.as_tensor(triples, dtype=torch.long)
        self.num_entities = num_entities
        self.num_relations = num_relations
        self.split = split

    def __len__(self) -> int:
        """Return the number of triples in the dataset."""
        return len(self.triples)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the indexed triple at the given index."""
        return self.triples[idx]

    def entities(self) -> torch.Tensor:
        """Return a tensor of all entity IDs in the dataset."""
        return torch.arange(self.num_entities)

    def relations(self) -> torch.Tensor:
        """Return a tensor of all relation IDs in the dataset."""
        return torch.arange(self.num_relations)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(split={self.split!r}, "
            f"num_triples={len(self.triples)}, "
            f"num_entities={self.num_entities}, "
            f"num_relations={self.num_relations})"
        )


# ---------------------------- Helper functions --------------------------------


def _map_to_ids(
    triples: List[Tuple[str, str, str]],
) -> Tuple[List[Tuple[int, int, int]], Dict[str, int], Dict[str, int]]:
    ent2id: Dict[str, int] = {}
    rel2id: Dict[str, int] = {}
    mapped: List[Tuple[int, int, int]] = []

    def _id(d: Dict[str, int], k: str) -> int:
        d.setdefault(k, len(d))
        return d[k]

    for h, r, t in triples:
        mapped.append((_id(ent2id, h), _id(rel2id, r), _id(ent2id, t)))

    return mapped, ent2id, rel2id


def load_kg_from_files(
    root: str, splits: Tuple[str, ...] = ("train", "valid", "test")
) -> Tuple[Dict[str, KGDataset], Dict[str, int], Dict[str, int]]:
    root_path = Path(root)
    split_raw, all_raw = {}, []
    for s in splits:
        lines = [tuple(line.split("\t")) for line in (root_path / f"{s}.txt").read_text().splitlines()]
        split_raw[s] = lines
        all_raw.extend(lines)
    mapped, ent2id, rel2id = _map_to_ids(all_raw)
    datasets, start = {}, 0
    for s in splits:
        end = start + len(split_raw[s])
        # Positional order matches KGDataset:
        # (triples_factory, triples, num_entities, num_relations, split)
        datasets[s] = KGDataset(None, mapped[start:end], len(ent2id), len(rel2id), s)
        start = end
    return datasets, ent2id, rel2id
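A minimal usage sketch for load_kg_from_files, assuming a hypothetical ./kg directory containing tab-separated train.txt / valid.txt / test.txt files:

from data.kg_dataset import load_kg_from_files

# "./kg" is a hypothetical directory with train.txt / valid.txt / test.txt,
# one "head\trelation\ttail" triple per line.
datasets, ent2id, rel2id = load_kg_from_files("./kg")
train = datasets["train"]
print(train)                      # KGDataset(split='train', num_triples=..., ...)
print(train[0])                   # LongTensor([h, r, t])
print(len(ent2id), len(rel2id))   # vocabulary sizes shared across splits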

data/open_bio_link.py (+57 -0)

from typing import List, Tuple

import torch
from pykeen.datasets import OpenBioLink

from .kg_dataset import KGDataset  # Importing KGDataset from the same package


class OpenBioLinkDataset(KGDataset):
    """A wrapper around PyKEEN's OpenBioLink dataset that produces KGDataset instances."""

    def __init__(self, split: str = "train") -> None:
        """
        Initialize the OpenBioLink wrapper.

        :param split: One of "train", "valid", or "test" indicating which split to load.
        """
        # Load the PyKEEN OpenBioLink dataset
        dataset = OpenBioLink()

        # Select the appropriate TriplesFactory based on the requested split
        if split == "train":
            tf = dataset.training
        elif split in ("valid", "val", "validation"):
            tf = dataset.validation
        elif split in ("test", "testing"):
            tf = dataset.testing
        else:
            msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
            raise ValueError(msg)

        # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3);
        # convert each row into a Python tuple of ints
        triples_list: List[Tuple[int, int, int]] = [
            (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples
        ]

        # Get the total number of entities and relations in the entire dataset
        num_entities = dataset.num_entities
        num_relations = dataset.num_relations

        # Initialize the base KGDataset with the selected triples
        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=num_entities,
            num_relations=num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        data = super().__getitem__(idx)

        # Convert the triple to a torch tensor
        return torch.tensor(data, dtype=torch.long)

data/openke_wiki.py (+57 -0)

from typing import List, Tuple

import torch
from pykeen.datasets import Wikidata5M

from .kg_dataset import KGDataset  # Importing KGDataset from the same package


class OpenKEWikiDataset(KGDataset):
    """A wrapper around PyKEEN's Wikidata5M dataset (used here as the OpenKE-Wiki corpus)."""

    def __init__(self, split: str = "train") -> None:
        """
        Initialize the Wikidata5M wrapper.

        :param split: One of "train", "valid", or "test" indicating which split to load.
        """
        # Load the PyKEEN Wikidata5M dataset
        dataset = Wikidata5M()

        # Select the appropriate TriplesFactory based on the requested split
        if split == "train":
            tf = dataset.training
        elif split in ("valid", "val", "validation"):
            tf = dataset.validation
        elif split in ("test", "testing"):
            tf = dataset.testing
        else:
            msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
            raise ValueError(msg)

        # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3);
        # convert each row into a Python tuple of ints
        triples_list: List[Tuple[int, int, int]] = [
            (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples
        ]

        # Get the total number of entities and relations in the entire dataset
        num_entities = dataset.num_entities
        num_relations = dataset.num_relations

        # Initialize the base KGDataset with the selected triples
        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=num_entities,
            num_relations=num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        data = super().__getitem__(idx)

        # Convert the triple to a torch tensor
        return torch.tensor(data, dtype=torch.long)

data/wn18.py (+112 -0)

from typing import List, Tuple

import torch
from pykeen.datasets import WN18, WN18RR

from .kg_dataset import KGDataset  # Importing KGDataset from the same package


class WN18Dataset(KGDataset):
    """A wrapper around PyKEEN's WN18 dataset that produces KGDataset instances."""

    def __init__(self, split: str = "train") -> None:
        """
        Initialize the WN18 wrapper.

        :param split: One of "train", "valid", or "test" indicating which split to load.
        """
        # Load the PyKEEN WN18 dataset
        dataset = WN18()

        # Select the appropriate TriplesFactory based on the requested split
        if split == "train":
            tf = dataset.training
        elif split in ("valid", "val", "validation"):
            tf = dataset.validation
        elif split in ("test", "testing"):
            tf = dataset.testing
        else:
            msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
            raise ValueError(msg)

        # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3);
        # convert each row into a Python tuple of ints
        triples_list: List[Tuple[int, int, int]] = [
            (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples
        ]

        # Get the total number of entities and relations in the entire dataset
        num_entities = dataset.num_entities
        num_relations = dataset.num_relations

        # Initialize the base KGDataset with the selected triples
        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=num_entities,
            num_relations=num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        data = super().__getitem__(idx)

        # Convert the triple to a torch tensor
        return torch.tensor(data, dtype=torch.long)


class WN18RRDataset(KGDataset):
    """A wrapper around PyKEEN's WN18RR dataset that produces KGDataset instances."""

    def __init__(self, split: str = "train") -> None:
        """
        Initialize the WN18RR wrapper.

        :param split: One of "train", "valid", or "test" indicating which split to load.
        """
        # Load the PyKEEN WN18RR dataset
        dataset = WN18RR()

        # Select the appropriate TriplesFactory based on the requested split
        if split == "train":
            tf = dataset.training
        elif split in ("valid", "val", "validation"):
            tf = dataset.validation
        elif split in ("test", "testing"):
            tf = dataset.testing
        else:
            msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
            raise ValueError(msg)

        # `tf.mapped_triples` is a LongTensor of shape (n_triples, 3);
        # convert each row into a Python tuple of ints
        triples_list: List[Tuple[int, int, int]] = [
            (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples
        ]

        # Get the total number of entities and relations in the entire dataset
        num_entities = dataset.num_entities
        num_relations = dataset.num_relations

        # Initialize the base KGDataset with the selected triples
        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=num_entities,
            num_relations=num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        data = super().__getitem__(idx)

        # Convert the triple to a torch tensor
        return torch.tensor(data, dtype=torch.long)

data/yago3_10.py (+70 -0)

from typing import List, Tuple

import torch
from pykeen.datasets import PathDataset

from .kg_dataset import KGDataset


class YAGO310Dataset(KGDataset):
    """Wrapper around PyKEEN's YAGO3-10 dataset."""

    def __init__(self, split: str = "train") -> None:
        # dataset = YAGO310()
        dataset = PathDataset(
            training_path="/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10/train.txt",
            testing_path="/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10/test.txt",
            validation_path="/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10/valid.txt",
        )

        if split == "train":
            # if True:
            tf = dataset.training
        elif split in ("valid", "val", "validation"):
            tf = dataset.validation
        elif split in ("test", "testing"):
            tf = dataset.testing
        else:
            msg = f"Unrecognized split '{split}'"
            raise ValueError(msg)

        # Swap in the pre-built CREC-tuned triples (hard-coded local path).
        custom_triples = torch.load(
            "/home/n_kazemi/projects/KGEvaluation/CustomDataset/yago310/yago310_crec_bi_1.pt",
            map_location=torch.device("cpu"),
        ).to(torch.long)
        # .tolist()

        # Get a random sample of size 5000
        # seed = torch.seed()  # For reproducibility
        # torch.manual_seed(seed)
        # torch.cuda.manual_seed(seed)
        # if custom_triples.shape[0] > 5000:
        #     custom_triples = custom_triples[torch.randperm(custom_triples.shape[0])[:5000]]

        tf = tf.clone_and_exchange_triples(custom_triples)

        # triples = tf.mapped_triples
        # # get a random sample of size 5000
        # if triples.shape[0] > 5000:
        #     indices = torch.randperm(triples.shape[0])[:5000]
        #     triples = triples[indices]
        #
        # tf = tf.clone_and_exchange_triples(triples)

        triples_list: List[Tuple[int, int, int]] = [
            (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples
        ]

        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=dataset.num_entities,
            num_relations=dataset.num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.tensor(super().__getitem__(idx), dtype=torch.long)

eval_datasets.py (+38 -0)

from data import (
    FB15KDataset,
    WN18Dataset,
    WN18RRDataset,
    YAGO310Dataset,
)
from metrics.wlcrec import WLCREC


def main() -> None:
    datasets = {
        "YAGO3-10": YAGO310Dataset(split="train"),
        # "WN18": WN18Dataset(split="train"),
        # "WN18RR": WN18RRDataset(split="train"),
        # "FB15K": FB15KDataset(split="train"),
        # "Hetionet": HetionetDataset(split="train"),
        # "OpenBioLink": OpenBioLinkDataset(split="train"),
        # "OpenKEWiki": OpenKEWikiDataset(split="train"),
    }

    for name, dataset in datasets.items():
        wl_crec = WLCREC(dataset)
        entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = wl_crec.compute(H=20)
        print(
            f"\nDataset: {name}",
            "\nResults (H=20)",
            f"\n  • avg WLEC        : {entropy:.6f} nats",
            f"\n  • C_ratio         : {c_ratio:.6f}",
            f"\n  • C_NWLEC         : {c_nwlec:.6f}",
            f"\n  • H_cond(R|S_H)   : {h_cond:.6f} nats",
            f"\n  • D_ratio         : {d_ratio:.6f}",
            f"\n  • D_NWLEC         : {d_nwlec:.6f}",
            sep="",
        )


if __name__ == "__main__":
    main()

main.py (+80 -0)

from copy import deepcopy

import hydra
import lovely_tensors as lt
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig

from metrics.c_swklf import CSWKLF
from tools import get_pretty_logger

# (Import whatever Trainer or runner you have)
from training.trainer import Trainer

logger = get_pretty_logger(__name__)

lt.monkey_patch()


@hydra.main(config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    # Detect CUDA
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    common_params = cfg.common
    training_params = deepcopy(cfg.training)
    # Drop the model_trainer node; it is instantiated separately below
    del training_params.model_trainer

    common_params = instantiate(common_params)
    training_params = instantiate(training_params)

    # Dataset instantiation
    train_dataset = instantiate(cfg.data.train)
    val_dataset = instantiate(cfg.data.valid)

    model = instantiate(
        cfg.model,
        num_entities=train_dataset.num_entities,
        num_relations=train_dataset.num_relations,
        triples_factory=train_dataset.triples_factory,
        device=device,
    )

    model = model.to(device)

    model_trainer = instantiate(
        cfg.training.model_trainer,
        model=model,
        training_params=training_params,
        triples_factory=train_dataset.triples_factory,
        mapped_triples=train_dataset.triples,
        device=model.device,
    )

    # Instantiate the Trainer with the model_trainer
    trainer = Trainer(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        model_trainer=model_trainer,
        common_params=common_params,
        training_params=training_params,
        device=model.device,
    )

    if cfg.common.evaluate_only:
        trainer.evaluate()
        c_swklf_metric = CSWKLF(
            dataset=val_dataset,
            model=model,
        )
        c_swklf_value = c_swklf_metric.compute()
        logger.info(f"CSWKLF Metric Value: {c_swklf_value}")
    else:
        trainer.run()


if __name__ == "__main__":
    main()

metrics/__init__.py (+0 -0)


metrics/base_metric.py (+55 -0)

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Tuple

if TYPE_CHECKING:
    from data.kg_dataset import KGDataset


class BaseMetric(ABC):
    """Base class for metrics that compute Weisfeiler-Lehman Entropy
    Complexity (WLEC) or similar metrics.
    """

    def __init__(self, dataset: KGDataset) -> None:
        self.dataset = dataset
        self.num_entities = dataset.num_entities

        self.out_adj, self.in_adj = self._build_adjacency_lists()

    def _build_adjacency_lists(
        self,
    ) -> Tuple[List[List[Tuple[int, int]]], List[List[Tuple[int, int]]]]:
        """Create *in* and *out* adjacency lists from ``self.dataset.triples``,
        a tensor of shape *(m, 3)* containing *(h, r, t)* triples
        (``dtype=torch.long``); ``self.num_entities`` sizes the containers.

        Returns
        -------
        out_adj, in_adj
            Each element ``out_adj[v]`` is a list ``[(r, t), ...]``. Each element
            ``in_adj[v]`` is a list ``[(r, h), ...]``.
        """
        triples = self.dataset.triples  # (m, 3)
        num_entities = self.num_entities

        out_adj: List[List[Tuple[int, int]]] = [[] for _ in range(num_entities)]
        in_adj: List[List[Tuple[int, int]]] = [[] for _ in range(num_entities)]
        for h, r, t in triples.tolist():
            out_adj[h].append((r, t))
            in_adj[t].append((r, h))
        return out_adj, in_adj

    @abstractmethod
    def compute(self, *args, **kwargs) -> Any: ...
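For orientation, a concrete metric only needs a compute() implementation on top of the prebuilt adjacency lists; the DegreeEntropy class below is a hypothetical illustration, not part of this commit:

import math
from collections import Counter

from metrics.base_metric import BaseMetric


class DegreeEntropy(BaseMetric):
    """Hypothetical example metric: Shannon entropy of the out-degree distribution."""

    def compute(self) -> float:
        degrees = [len(neighbours) for neighbours in self.out_adj]
        counts = Counter(degrees)
        n = len(degrees)
        # H = -Σ p(d) log p(d) over the empirical out-degree distribution (nats)
        return -sum((c / n) * math.log(c / n) for c in counts.values())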

metrics/c_swklf.py (+264 -0)

from __future__ import annotations

import math
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Iterable

import torch
from torch import Tensor

from .base_metric import BaseMetric

if TYPE_CHECKING:
    from data.kg_dataset import KGDataset
    from models.base_model import ERModel


@torch.no_grad()
def _log_prob_distribution(
    fn: Callable[[Tensor], Tensor],
    query: Tensor,
    batch_size: int,
) -> Iterable[Tensor]:
    """
    Yield *log*-probability batches for an arbitrary distribution helper.

    Parameters
    ----------
    fn :
        A bound method of *model* returning **log**-probabilities, e.g.
        ``model.tail_distribution(..., log=True)``.
    query :
        LongTensor of shape (N, 2) containing the query IDs that `fn` expects.
    batch_size :
        Mini-batch size. Choose according to available GPU/CPU memory.
    """
    for i in range(0, len(query), batch_size):
        yield fn(query[i : i + batch_size])  # (B, |Candidates|)


def _kl_uniform(log_p: Tensor, true_index_sets: list[list[int]]) -> Tensor:
    """
    KL(q || p) where q is uniform over *true_index_sets*.

    Parameters
    ----------
    log_p :
        Tensor of shape (B, N) — log-probabilities from the model.
    true_index_sets :
        List (length B). Element *b* contains the list of **indices**
        that are correct for row *b*.

    Returns
    -------
    kl : Tensor, shape (B,) — KL divergence per row (natural log).
    """
    rows: list[Tensor] = []
    for lp_row, idx in zip(log_p, true_index_sets):
        k = len(idx)
        # KL(q || p) = E_q[log q] - E_q[log p] = -log k - E_q[log p]
        rows.append(-math.log(k) - lp_row[idx].mean())
    return torch.tensor(rows, device=log_p.device)
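# Worked example for _kl_uniform (illustrative numbers): with a single row
# p = (0.5, 0.25, 0.25) and true index set {0, 1}, q is uniform over {0, 1}, so
#     KL(q || p) = -log 2 - mean(log 0.5, log 0.25)
#                = -0.6931 - (-1.0397) ≈ 0.3466 nats,
# matching the direct computation
#     0.5·log(0.5/0.5) + 0.5·log(0.5/0.25) = 0.5·log 2 ≈ 0.3466.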


@dataclass(slots=True)
class _Accum:
    """Simple accumulator for ∑ value and ∑ weight."""

    tot_v: float = 0.0
    tot_w: float = 0.0

    def update(self, value: float, weight: float) -> None:
        self.tot_v += value * weight
        self.tot_w += weight

    @property
    def mean(self) -> float:
        return self.tot_v / self.tot_w if self.tot_w else float("nan")


class CSWKLF(BaseMetric):
    """
    C-SWKLF metric for evaluating knowledge graph embeddings.

    Evaluates embedding quality by computing the Comprehensive
    Structural-Weighted KL-Fitness score.
    """

    def __init__(self, dataset: KGDataset, model: ERModel) -> None:
        super().__init__(dataset)
        self.model = model

    @torch.no_grad()
    def compute(
        self,
        alpha: float = 1 / 3,
        beta: float = 1 / 3,
        gamma: float = 1 / 3,
        batch_size: int = 1024,
        slice_size: int | None = 2048,
        device: torch.device | str | None = None,
    ) -> float:
        """
        Compute *Comprehensive Structural-Weighted KL-Fitness*.

        Parameters
        ----------
        alpha, beta, gamma :
            Non-negative weights for the three query types (tail, head,
            relation), default 1/3 each. Must sum to *1*.
        batch_size :
            Number of queries scored per forward pass. Tune wrt. GPU/CPU RAM.
        slice_size :
            Forwarded to `tail_distribution` / `head_distribution` /
            `relation_distribution`. `None` disables slicing.
        device :
            Where to do the computation. `None` ⇒ first param's device.

        Returns
        -------
        score : float in (0, 1]
            Higher = model closer to empirical distributions across *all* directions.
        """
        # ------------------------------------------------------------------ #
        # Preparation                                                        #
        # ------------------------------------------------------------------ #
        assert abs(alpha + beta + gamma - 1.0) < 1e-9, "α+β+γ must be 1."

        model_device = next(self.model.parameters()).device
        if device is None:
            device = model_device
        triples = self.dataset.triples.to(device)  # (|T|, 3)

        heads, rels, tails = triples.t()  # (|T|,)

        # Build index structures ------------------------------------------------
        tails_by_hr: dict[tuple[int, int], list[int]] = defaultdict(list)
        heads_by_rt: dict[tuple[int, int], list[int]] = defaultdict(list)
        rels_by_ht: dict[tuple[int, int], list[int]] = defaultdict(list)

        for h, r, t in zip(heads.tolist(), rels.tolist(), tails.tolist()):
            tails_by_hr[(h, r)].append(t)
            heads_by_rt[(r, t)].append(h)
            rels_by_ht[(h, t)].append(r)

        # Structural entropies & query lists -------------------------------------
        H_tail: dict[int, _Accum] = defaultdict(_Accum)  # per relation r
        H_head: dict[int, _Accum] = defaultdict(_Accum)
        tail_queries: list[tuple[int, int]] = []  # (h, r)
        head_queries: list[tuple[int, int]] = []  # (r, t)
        rel_queries: list[tuple[int, int]] = []  # (h, t)

        # --- tails --------------------------------------------------------------
        for (h, r), ts in tails_by_hr.items():
            k = len(ts)
            H_tail[r].update(math.log(k), 1.0)
            tail_queries.append((h, r))

        # --- heads --------------------------------------------------------------
        for (r, t), hs in heads_by_rt.items():
            k = len(hs)
            H_head[r].update(math.log(k), 1.0)
            head_queries.append((r, t))

        # --- relations ----------------------------------------------------------
        H_rel_accum = _Accum()
        for (h, t), rs in rels_by_ht.items():
            k = len(rs)
            H_rel_accum.update(math.log(k), 1.0)
            rel_queries.append((h, t))

        H_struct_tail = {r: acc.mean for r, acc in H_tail.items()}
        H_struct_head = {r: acc.mean for r, acc in H_head.items()}
        # H_struct_rel = H_rel_accum.mean  # scalar

        # ------------------------------------------------------------------ #
        # KL divergences                                                     #
        # ------------------------------------------------------------------ #
        def _avg_kl_per_relation(
            queries: list[tuple[int, int]],
            true_idx_map: dict[tuple[int, int], list[int]],
            distribution_fn: Callable[[Tensor], Tensor],
            key_on_first: bool = False,  # key results by the first query element
        ) -> dict[int, float]:
            """
            Compute *average* KL(q || p) **per relation**.

            Works for the tail & head conditionals. With ``key_on_first=True``
            the relation ID is the first element of each query (head queries
            are (r, t)); otherwise it is the second ((h, r) tail queries).
            """
            kl_sum: dict[int, float] = defaultdict(float)
            count: dict[int, int] = defaultdict(int)

            query_tensor = torch.tensor(queries, dtype=torch.long, device=device)

            for log_p in _log_prob_distribution(distribution_fn, query_tensor, batch_size):
                # slice_size handled inside distribution_fn
                # log_p : (B, |Candidates|)
                b = log_p.shape[0]
                # Map back to the current slice of queries, then consume it
                slice_queries = queries[:b]
                queries[:] = queries[b:]

                true_sets = [true_idx_map[q] for q in slice_queries]
                kl_row = _kl_uniform(log_p, true_sets)  # (B,)

                for (x, r), kl_val in zip(slice_queries, kl_row.tolist()):
                    rel_id = x if key_on_first else r
                    kl_sum[rel_id] += kl_val
                    count[rel_id] += 1

            return {r: kl_sum[r] / count[r] for r in kl_sum}

        # --- tails --------------------------------------------------------------
        tail_queries_copy = tail_queries.copy()
        D_tail = _avg_kl_per_relation(
            tail_queries_copy,
            tails_by_hr,
            lambda q, s=slice_size: self.model.tail_distribution(q, log=True, slice_size=s),
        )
        S_tail = {r: math.exp(-d) for r, d in D_tail.items()}

        # --- heads --------------------------------------------------------------
        head_queries_copy = head_queries.copy()
        D_head = _avg_kl_per_relation(
            head_queries_copy,
            heads_by_rt,
            lambda q, s=slice_size: self.model.head_distribution(q, log=True, slice_size=s),
            key_on_first=True,
        )
        S_head = {r: math.exp(-d) for r, d in D_head.items()}

        # --- relations (single global number) ----------------------------------
        D_rel_sum = 0.0
        for log_p in _log_prob_distribution(
            lambda q, s=slice_size: self.model.relation_distribution(q, log=True, slice_size=s),
            torch.tensor(rel_queries, dtype=torch.long, device=device),
            batch_size,
        ):
            slice_size_batch = log_p.shape[0]
            slice_queries = rel_queries[:slice_size_batch]
            rel_queries[:] = rel_queries[slice_size_batch:]

            true_sets = [rels_by_ht[q] for q in slice_queries]
            D_rel_sum += _kl_uniform(log_p, true_sets).sum().item()

        D_rel = D_rel_sum / H_rel_accum.tot_w
        S_rel = math.exp(-D_rel)

        # ------------------------------------------------------------------ #
        # Weighted aggregation → C-SWKLF                                     #
        # ------------------------------------------------------------------ #
        def _weighted_mean(score: dict[int, float], weight: dict[int, float]) -> float:
            num = sum(weight[r] * score[r] for r in score)
            den = sum(weight[r] for r in score)
            return num / den if den else float("nan")

        sw_tail = _weighted_mean(S_tail, H_struct_tail)
        sw_head = _weighted_mean(S_head, H_struct_head)
        # Relations: numerator & denominator cancel
        sw_rel = S_rel

        return alpha * sw_tail + beta * sw_head + gamma * sw_rel
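To make the aggregation step concrete, a small worked example with invented numbers: with two relations whose average tail KLs are D = {0: 0.1, 1: 0.5} and structural-entropy weights H = {0: 1.0, 1: 3.0}, the snippet below reproduces the weighted-mean step.

import math

# Invented per-relation KL divergences and structural-entropy weights.
D_tail = {0: 0.1, 1: 0.5}
H_struct_tail = {0: 1.0, 1: 3.0}

S_tail = {r: math.exp(-d) for r, d in D_tail.items()}  # fitness in (0, 1]
sw_tail = sum(H_struct_tail[r] * S_tail[r] for r in S_tail) / sum(H_struct_tail.values())
print(round(sw_tail, 4))  # ≈ (1.0·0.9048 + 3.0·0.6065) / 4.0 ≈ 0.6811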

metrics/crec_modifier.py (+732 -0)

# from __future__ import annotations

# import math
# import os
# import random
# import threading
# from collections import defaultdict
# from concurrent.futures import ThreadPoolExecutor, as_completed
# from typing import Dict, List, Tuple

# import torch

# from data.kg_dataset import KGDataset
# from metrics.wlcrec import WLCREC # Assuming WLCREC is defined in
# from tools import get_pretty_logger

# logger = get_pretty_logger(__name__)


# # (Edge-editing primitives: identical to the active definitions below.)


# def _search_worker(
# seed: int,
# triples_init: List[Tuple[int, int, int]],
# triples_factory,
# n_ent: int,
# n_rel: int,
# depth: int,
# lo: float,
# hi: float,
# max_iters: int,
# ) -> List[Tuple[int, int, int]] | None:
# """
# Run the exact same hill‑climb that `tune_crec()` did,
# but entirely in this process. Return the edited triples
# once c falls in [lo, hi]; return None if we used up all iterations.
# """
# rng = random.Random(seed)
# triples = triples_init.copy().tolist()

# for it in range(max_iters):
# # WL‑CREC is *only* recomputed every 1000 edits, exactly like before
# if it % 1000 == 0:
# dataset = KGDataset(
# triples_factory,
# triples=torch.tensor(triples, dtype=torch.long),
# num_entities=n_ent,
# num_relations=n_rel,
# )
# wl = WLCREC(dataset)
# colours = wl.wl_colours(depth)
# *_, c, _ = wl.compute(H=5) # unchanged API

# if lo <= c <= hi: # success
# logger.info("[seed %d] hit %.4f in %d edits", seed, c, it)
# return triples

# # ---------- identical edit logic ----------
# if c < lo:
# if rng.random() < 0.5:
# rem_det_edge(triples, colours, depth, rng)
# else:
# add_div_edge(triples, colours, depth, n_ent, n_rel, rng)
# else: # c > hi
# if rng.random() < 0.5:
# rem_div_edge(triples, colours, depth, rng)
# else:
# add_det_edge(triples, colours, depth, n_rel, rng)
# # ------------------------------------------

# return None # used up our budget


# # --------------------------------------------------------------------
# # Unified tuner
# # --------------------------------------------------------------------
# def tune_crec(
# wl_crec: WLCREC,
# n_ent: int,
# n_rel: int,
# depth: int,
# lo: float,
# hi: float,
# max_iters: int = 80_000,
# seed: int = 42,
# ):
# triples = wl_crec.dataset.triples.tolist()

# rng = random.Random(seed)
# for it in range(max_iters):
# print(f"\r[iter {it + 1:5d}] ", end="")
# if it % 1000 == 0:
# dataset = KGDataset(
# wl_crec.dataset.triples_factory,
# triples=torch.tensor(triples, dtype=torch.long),
# num_entities=n_ent,
# num_relations=n_rel,
# )
# tmp_wl_crec = WLCREC(dataset)

# # colours = wl_colours(triples, n_ent, depth)
# colours = tmp_wl_crec.wl_colours(depth)
# _, _, _, _, c, _ = tmp_wl_crec.compute(H=5)
# if lo <= c <= hi:
# logging.info("WL-CREC %.4f reached after %d edits (|T|=%d)", c, it, len(triples))
# return triples
# if c < lo:
# # need ↑ WL-CREC → prefer deletion of deterministic, else add diversifying
# if rng.random() < 0.5:
# rem_det_edge(triples, colours, depth, rng)
# else:
# add_div_edge(triples, colours, depth, n_ent, n_rel, rng)
# # need ↓ WL-CREC → prefer deletion of diversifying, else add deterministic
# elif rng.random() < 0.5:
# rem_div_edge(triples, colours, depth, rng)
# else:
# add_det_edge(triples, colours, depth, n_rel, rng)
# if (it + 1) % 10000 == 1:
# logging.info("[iter %d] WL-CREC %.4f |T|=%d", it + 1, c, len(triples))
# raise RuntimeError("Exceeded max iterations without hitting target band.")


# def _edit_batch(
# worker_id: int,
# n_edits: int,
# colours: List[int],
# crec: float,
# target_lo: float,
# target_hi: float,
# depth: int,
# n_ent: int,
# n_rel: int,
# seed_base: int,
# ):
# """
# Perform `n_edits` topology modifications *locally* and return
# (added_triples, removed_triples) lists.
# """
# rng = random.Random(seed_base + worker_id)
# local_added, local_removed = [], []

# # local graph views: only sizes matter for edit selection
# for _ in range(n_edits):
# if crec < target_lo: # need ↑ WL-CREC
# if rng.random() < 0.5: # • remove deterministic
# # We cannot remove by index safely without the global list;
# # choose a *signature* and remember intention to delete:
# local_removed.append(("det", rng.random()))
# else: # • add diversifying
# h = rng.randrange(n_ent)
# t = rng.randrange(n_ent)
# while colours[h] == colours[t]:
# h = rng.randrange(n_ent)
# t = rng.randrange(n_ent)
# r = rng.randrange(n_rel)
# local_added.append((h, r, t))
# else: # need ↓ WL-CREC
# if rng.random() < 0.5: # • remove diversifying
# local_removed.append(("div", rng.random()))
# else: # • add deterministic
# σ = rng.choice(colours)
# τ = rng.choice(colours)
# h = colours.index(σ)
# t = colours.index(τ)
# local_added.append((h, 0, t))

# return local_added, local_removed


# def wl_colours(triples, n_ent, depth):
# # build adjacency once
# out_adj, in_adj = defaultdict(list), defaultdict(list)
# for h, r, t in triples:
# out_adj[h].append((r, t))
# in_adj[t].append((r, h))

# colours_rounds = [[0] * n_ent] # round-0 colours
# for h in range(1, depth + 1):
# prev = colours_rounds[-1]

# # 1) build textual signatures in parallel
# def sig(v):
# neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [
# ("↑", r, prev[u]) for r, u in in_adj.get(v, [])
# ]
# neigh.sort()
# return (prev[v], tuple(neigh))

# with ThreadPoolExecutor() as tpe: # cheap threads inside worker
# sigs = list(tpe.map(sig, range(n_ent)))

# # 2) assign deterministic colour IDs
# sig2id: Dict[Tuple, int] = {}
# next_round = [0] * n_ent
# fresh = 0
# for v, sg in enumerate(sigs):
# cid = sig2id.setdefault(sg, fresh)
# if cid == fresh:
# fresh += 1
# next_round[v] = cid
# colours_rounds.append(next_round)

# depth_colours = colours_rounds[-1]

# return depth_colours


# def _metric_worker(args):
# triples, n_ent, n_rel, depth = args
# dataset = KGDataset(
# triples_factory=None, # Not used in this context
# triples=torch.tensor(triples, dtype=torch.long),
# num_entities=n_ent,
# num_relations=n_rel,
# )
# wl_crec = WLCREC(dataset)
# _, _, _, _, c, _ = wl_crec.compute(H=5, return_full=False)
# colours = wl_colours(triples, n_ent, depth)
# return c, colours


# def tune_crec_parallel_edits(
# triples_init,
# n_ent: int,
# n_rel: int,
# depth: int,
# target_lo: float,
# target_hi: float,
# max_iters: int = 80_000,
# metric_every: int = 100,
# n_workers: int = max(20, math.ceil(os.cpu_count() / 2)),
# seed: int = 42,
# ):
# # -------- shared mutable state (main thread owns it) --------
# triples = triples_init.tolist()
# out_adj, in_adj = defaultdict(list), defaultdict(list)
# for h, r, t in triples:
# out_adj[h].append((r, t))
# in_adj[t].append((r, h))

# pool = ThreadPoolExecutor(max_workers=n_workers)
# metric_lock = threading.Lock() # exactly one metric at a time
# rng_global = random.Random(seed)

# # ----- first metric checkpoint -----
# crec, colours = _metric_worker((triples, n_ent, n_rel, depth))

# edit_budget_total = 0
# for it in range(0, max_iters, metric_every):
# # =========================================================
# # 1. PARALLEL EDIT STAGE (metric_every edits in total)
# # =========================================================
# futures = []
# edits_per_worker = metric_every // n_workers
# extra = metric_every % n_workers
# for wid in range(n_workers):
# n_edits = edits_per_worker + (1 if wid < extra else 0)
# futures.append(
# pool.submit(
# _edit_batch,
# wid,
# n_edits,
# colours,
# crec,
# target_lo,
# target_hi,
# depth,
# n_ent,
# n_rel,
# seed,
# )
# )

# # merge when workers finish
# for fut in as_completed(futures):
# added, removed_specs = fut.result()
# # --- apply additions immediately (cheap, conflict-free)
# for h, r, t in added:
# triples.append((h, r, t))
# out_adj[h].append((r, t))
# in_adj[t].append((r, h))
# # --- apply removals: interpret the spec on *current* graph
# for typ, randv in removed_specs:
# if typ == "div":
# idxs = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]]
# else: # 'det'
# sig_rel = defaultdict(set)
# for h, r, t in triples:
# sig_rel[(colours[h], colours[t])].add(r)
# idxs = [
# i
# for i, (h, r, t) in enumerate(triples)
# if len(sig_rel[(colours[h], colours[t])]) == 1
# ]
# if idxs:
# victim = idxs[int(randv * len(idxs))]
# h, r, t = triples.pop(victim)
# out_adj[h].remove((r, t))
# in_adj[t].remove((r, h))

# edit_budget_total += metric_every

# # =========================================================
# # 2. SINGLE-THREADED METRIC CHECKPOINT (synchronised)
# # =========================================================
# # if edit_budget_total % (10 * metric_every) == 0:
# if True:
# with metric_lock:
# crec, colours = _metric_worker((triples, n_ent, n_rel, depth))
# logging.info(
# "After %d edits WL-CREC = %.4f |T|=%d", edit_budget_total, crec, len(triples)
# )

# if target_lo <= crec <= target_hi:
# logging.info("Target band reached.")
# pool.shutdown(wait=True)
# return triples

# pool.shutdown(wait=True)
# raise RuntimeError("Exceeded max iteration budget without success.")


# from __future__ import annotations

# import logging
# import random
# from collections import defaultdict
# from concurrent.futures import ThreadPoolExecutor, as_completed
# from typing import List, Tuple

# import torch

# from data.kg_dataset import KGDataset
# from metrics.wlcrec import WLCREC
# from tools import get_pretty_logger

# logger = get_pretty_logger(__name__)

# # ------------ additions ------------
# def propose_add_div(triples: List, colours, n_ent: int, n_rel: int, rng: random.Random):
# for _ in range(1000):
# h, t = rng.randrange(n_ent), rng.randrange(n_ent)
# if colours[h] != colours[t]:
# r = rng.randrange(n_rel)
# return ("add", (h, r, t))
# return None # fell through – extremely rare


# def propose_add_det(colours, n_rel: int, rng: random.Random):
# σ, τ = rng.choice(colours), rng.choice(colours)
# h, t = colours.index(σ), colours.index(τ)
# return ("add", (h, 0, t)) # rel 0 = deterministic


# # ------------ removals ------------
# def propose_rem_div(triples: List, colours, rng: random.Random):
# cand = [trp for trp in triples if colours[trp[0]] != colours[trp[2]]]
# return ("rem", rng.choice(cand)) if cand else None


# def propose_rem_det(triples: List, colours, rng: random.Random):
# sig_rel = defaultdict(set)
# for h, r, t in triples:
# sig_rel[(colours[h], colours[t])].add(r)
# cand = [trp for trp in triples if len(sig_rel[(colours[trp[0]], colours[trp[2]])]) == 1]
# return ("rem", rng.choice(cand)) if cand else None


# def make_edit_proposal(
# triples_snapshot: List,
# colours_snapshot,
# c: float,
# lo: float,
# hi: float,
# n_ent: int,
# n_rel: int,
# depth: int, # still here for future use / signatures
# seed: int,
# ):
# """Return exactly one ('add' | 'rem', triple) proposal or None."""
# rng = random.Random(seed)

# # -- decide which kind of edit we want, *given* the current c ------------
# if c < lo: # ↑ WL‑CREC (delete deterministic ∨ add diversifying)
# chooser = (propose_rem_det, propose_add_div)
# else: # ↓ WL‑CREC (delete diversifying ∨ add deterministic)
# chooser = (propose_rem_div, propose_add_det)

# op = rng.choice(chooser)
# return (
# op(triples_snapshot, colours_snapshot, n_ent, n_rel, rng)
# if op.__name__.startswith("propose_add")
# else op(triples_snapshot, colours_snapshot, rng)
# )


# def tune_crec_parallel_edits(
# triples: List,
# n_ent: int,
# n_rel: int,
# depth: int,
# lo: float,
# hi: float,
# *,
# max_iters: int = 80_000,
# edits_per_eval: int = 1000, # == old “if it % 1000 == 0”
# batch_size: int = 256, # how many proposals we farm out at once
# max_workers: int = 4,
# seed: int = 42,
# ) -> List:
# rng_global = random.Random(seed)

# triples = set(triples) # deduplicate if needed

# with ThreadPoolExecutor(max_workers=max_workers) as pool:
# proposal_seed = seed * 997 # deterministic but different stream

# # edit_counter = 0
# # while edit_counter < max_iters:
# # # ----------------- expensive part (single‑thread) ----------------
# # dataset = KGDataset(
# # triples_factory=None, # Not used in this context
# # triples=torch.tensor(triples, dtype=torch.long),
# # num_entities=n_ent,
# # num_relations=n_rel,
# # )
# # tmp = WLCREC(dataset)
# # colours = tmp.wl_colours(depth)
# # *_, c, _ = tmp.compute(H=5)
# # # -----------------------------------------------------------------

# # if lo <= c <= hi:
# # logger.info(
# # "WL‑CREC %.4f reached after %d edits |T|=%d", c, edit_counter, len(triples)
# # )
# # return triples

# # ============ parallel block: just make `edits_per_eval` proposals
# needed = min(edits_per_eval, max_iters - edit_counter)
# proposals = []

# while len(proposals) < needed:
# # launch a batch of workers
# futs = [
# pool.submit(
# make_edit_proposal,
# triples,
# colours,
# c,
# lo,
# hi,
# n_ent,
# n_rel,
# depth,
# proposal_seed + i,
# )
# for i in range(batch_size)
# ]
# for f in as_completed(futs):
# prop = f.result()
# if prop is not None:
# proposals.append(prop)
# if len(proposals) == needed:
# break
# proposal_seed += batch_size # move RNG window forward

# # -------------- apply the gathered proposals *sequentially* -------
# for kind, trp in proposals:
# if kind == "add":
# triples.append(trp)
# else: # "rem"
# try:
# triples.remove(trp)
# except ValueError:
# pass # already gone – benign collision
# # -----------------------------------------------------------------
# edit_counter += needed
# if edit_counter % 1_000 == 0:
# logger.info("[iter %d] c=%.4f |T|=%d", edit_counter, c, len(triples))

# raise RuntimeError("Exceeded max_iters without hitting target band.")


import random
import threading
from collections import defaultdict
from typing import Dict, List, Tuple

# --------------------------------------------------------------------
# Logging helper (replace with your own if you prefer) ---------------
# --------------------------------------------------------------------
from tools import get_pretty_logger

logger = get_pretty_logger(__name__)

# --------------------------------------------------------------------
# Original edge-editing primitives (unchanged) -----------------------
# --------------------------------------------------------------------


def add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng):
    """Add a 'diversifying' edge between two differently coloured nodes."""
    for _ in range(1000):
        h = rng.randrange(n_ent)
        t = rng.randrange(n_ent)
        if colours[h] != colours[t]:
            r = rng.randrange(n_rel)
            triples.append((h, r, t))
            out_adj[h].append((r, t))
            in_adj[t].append((r, h))
            return


def add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng):
    """Add a 'deterministic' edge (relation 0) between two colour classes."""
    σ = rng.choice(colours)
    τ = rng.choice(colours)
    h = colours.index(σ)
    t = colours.index(τ)
    triples.append((h, 0, t))
    out_adj[h].append((0, t))
    in_adj[t].append((0, h))


def rem_div_edge(triples, out_adj, in_adj, colours, depth, rng):
    """Remove one edge whose endpoints have different colours."""
    cand = [i for i, (h, _, t) in enumerate(triples) if colours[h] != colours[t]]
    if cand:
        idx = rng.choice(cand)
        h, r, t = triples.pop(idx)
        out_adj[h].remove((r, t))
        in_adj[t].remove((r, h))


def rem_det_edge(triples, out_adj, in_adj, colours, depth, rng):
    """Remove one edge whose colour signature admits a single relation."""
    sig_rel = defaultdict(set)
    for h, r, t in triples:
        σ, τ = colours[h], colours[t]
        sig_rel[(σ, τ)].add(r)
    cand = []
    for i, (h, r, t) in enumerate(triples):
        if len(sig_rel[(colours[h], colours[t])]) == 1:
            cand.append(i)
    if cand:
        idx = rng.choice(cand)
        h, r, t = triples.pop(idx)
        out_adj[h].remove((r, t))
        in_adj[t].remove((r, h))


def _worker(
    *,
    worker_id: int,
    rng: random.Random,
    triples: List[Tuple[int, int, int]],
    out_adj: Dict[int, List[Tuple[int, int]]],
    in_adj: Dict[int, List[Tuple[int, int]]],
    colours: List[int],
    n_ent: int,
    n_rel: int,
    depth: int,
    c: float,
    lo: float,
    hi: float,
    max_iters: int,
    state_lock: threading.Lock,
    stop_event: threading.Event,
):
    """One thread: mutate the *shared* structures until success/stop."""
    for it in range(max_iters):
        if stop_event.is_set():
            return  # someone else finished; exit early

        with state_lock:  # protect the shared graph
            # if lo <= c <= hi:
            #     logger.info(
            #         "[worker %d] converged after %d steps (CREC %.4f, |T|=%d)",
            #         worker_id, it, c, len(triples),
            #     )
            #     stop_event.set()
            #     return

            # Choose and apply one edit -----------------------------
            if c < lo:  # need ↑ CREC
                if rng.random() < 0.5:
                    rem_det_edge(triples, out_adj, in_adj, colours, depth, rng)
                else:
                    add_div_edge(triples, out_adj, in_adj, colours, depth, n_ent, n_rel, rng)
            elif rng.random() < 0.5:
                rem_div_edge(triples, out_adj, in_adj, colours, depth, rng)
            else:
                add_det_edge(triples, out_adj, in_adj, colours, depth, n_rel, rng)

    logger.warning("[worker %d] reached max_iters", worker_id)
    stop_event.set()
    return


# --------------------------------------------------------------------
# Public API ---------------------------------------------------------
# --------------------------------------------------------------------


def fast_tune_crec(
    triples: List,
    colours: List,
    n_ent: int,
    n_rel: int,
    depth: int,
    c: float,
    lo: float,
    hi: float,
    max_iters: int = 1000,
    max_workers: int | None = None,
    seeds: List[int] | None = None,
) -> List[Tuple[int, int, int]]:
    """Tune WL-CREC with *shared* triples using multiple threads.

    Returns the **same list instance** that was passed in, already
    modified in place by the winning thread.
    """
    if max_workers is None:
        # max_workers = os.cpu_count() or 4
        max_workers = 4
    if seeds is None:
        seeds = [42 + i for i in range(max_workers)]
    assert len(seeds) >= max_workers, "Need at least one seed per worker"

    # Prepare adjacency once (shared) --------------------------------
    out_adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
    in_adj: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
    for h, r, t in triples:
        out_adj[h].append((r, t))
        in_adj[t].append((r, h))

    state_lock = threading.Lock()
    stop_event = threading.Event()

    logger.info(
        "Launching %d threads on shared triples (target %.3f-%.3f)",
        max_workers,
        lo,
        hi,
    )

    threads: List[threading.Thread] = []
    for wid in range(max_workers):
        t = threading.Thread(
            name=f"tune-crec-worker-{wid}",
            target=_worker,
            kwargs=dict(
                worker_id=wid,
                rng=random.Random(seeds[wid]),
                triples=triples,
                out_adj=out_adj,
                in_adj=in_adj,
                colours=colours,
                n_ent=n_ent,
                n_rel=n_rel,
                depth=depth,
                c=c,
                lo=lo,
                hi=hi,
                max_iters=max_iters,
                state_lock=state_lock,
                stop_event=stop_event,
            ),
            daemon=False,
        )
        threads.append(t)
        t.start()

    for t in threads:
        t.join()

    if not stop_event.is_set():
        raise RuntimeError("No thread converged; try increasing max_iters or widening the band")

    return triples
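A hedged usage sketch for fast_tune_crec, mirroring the commented-out call in build_crec_datasets.py (the band, depth, and worker count are placeholders):

from data import YAGO310Dataset
from metrics.crec_modifier import fast_tune_crec
from metrics.wlcrec import WLCREC

ds = YAGO310Dataset()
wl = WLCREC(ds)
colours = wl.wl_colours(5)[-1]  # depth-5 WL colours, as in build_crec_datasets.py
c = wl.compute(5)[-2]           # current C_NWLEC value
tuned = fast_tune_crec(
    ds.triples.tolist(),
    colours,
    ds.num_entities,
    ds.num_relations,
    depth=5,
    c=c,
    lo=0.25,
    hi=0.26,
    max_iters=1000,
    max_workers=4,
)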

metrics/crec_radius_sample.py (+106 -0)

from __future__ import annotations

import random
from collections import defaultdict, deque
from math import isclose
from typing import List

import torch

from data.kg_dataset import KGDataset
from metrics.wlcrec import WLCREC
from tools import get_pretty_logger

logger = get_pretty_logger(__name__)


# --------------------------------------------------------------------
# Radius-k patch sampler
# --------------------------------------------------------------------
def radius_patch_sample(
    all_triples: List,
    ratio: float,
    radius: int,
    rng: random.Random,
) -> List:
    """
    Take ~ratio fraction of triples by picking random seeds and
    keeping *all* triples within ≤ radius hops (undirected) of them.
    """
    n_total = len(all_triples)
    target = int(ratio * n_total)

    # Build entity → triple-indices adjacency for BFS
    ent2trip = defaultdict(list)
    for idx, (h, _, t) in enumerate(all_triples):
        ent2trip[h].append(idx)
        ent2trip[t].append(idx)

    chosen_idx: set[int] = set()
    visited_ent = set()

    while len(chosen_idx) < target:
        seed_idx = rng.randrange(n_total)
        if seed_idx in chosen_idx:
            continue
        # BFS over entities, starting from the seed triple's endpoints
        queue = deque()
        for e in all_triples[seed_idx][::2]:  # subject & object (skip the relation)
            if e not in visited_ent:
                queue.append((e, 0))
                visited_ent.add(e)
        while queue:
            ent, dist = queue.popleft()
            for tidx in ent2trip[ent]:
                if tidx not in chosen_idx:
                    chosen_idx.add(tidx)
            if dist < radius:
                for tidx in ent2trip[ent]:
                    h, _, t = all_triples[tidx]
                    for nb in (h, t):
                        if nb not in visited_ent:
                            visited_ent.add(nb)
                            queue.append((nb, dist + 1))
        # guard against infinite loop on very small ratio
        if isclose(ratio, 0.0) and chosen_idx:
            break

    return [all_triples[i] for i in chosen_idx]


# --------------------------------------------------------------------
# Main driver: repeated sampling until WL-CREC band hit
# --------------------------------------------------------------------
def find_sample(
all_triples: List,
n_ent: int,
n_rel: int,
depth: int,
ratio: float,
radius: int,
lower: float,
upper: float,
max_tries: int,
seed: int,
) -> List:
for trial in range(1, max_tries + 1):
rng = random.Random(seed + trial) # fresh randomness each try
sample = radius_patch_sample(all_triples, ratio, radius, rng)

sample_ds = KGDataset(
triples_factory=None,
triples=torch.tensor(sample, dtype=torch.long),
num_entities=n_ent,
num_relations=n_rel,
split="train",
)
wl_crec = WLCREC(sample_ds)

crec_val = wl_crec.compute(depth)[-2]
logging.info("Try %d |T'|=%d WL-CREC=%.4f", trial, len(sample), crec_val)
if lower <= crec_val <= upper:
logging.info("Success after %d tries.", trial)
return sample
raise RuntimeError(f"No sample reached WL-CREC band in {max_tries} tries.")

+ 226
- 0
metrics/greedy_crec.py View File

from __future__ import annotations

import logging
import random
from math import log
from typing import Dict, List, Tuple

import torch
from torch import Tensor

from tools import get_pretty_logger

logging = get_pretty_logger(__name__)


Entity = int
Relation = int
Triple = Tuple[Entity, Relation, Entity]
Color = int


# ---------------------------------------------------------------------
# Weisfeiler–Lehman colouring (GPU friendly)
# ---------------------------------------------------------------------
def wl_colours_gpu(triples: Tensor, n_ent: int, depth: int, device: torch.device) -> List[Tensor]:
"""
GPU implementation of classic Weisfeiler‑Lehman refinement.

1. Each node starts with colour 0.
2. At iteration h we hash the multiset of
⬇︎ (r, colour(t)) and ⬆︎ (r, colour(h))
signatures into new colour ids.
3. Hashing is done with a fast 64‑bit mix; collisions are unlikely and
immaterial for CREC (only *relative* colours matter).
"""
h, r, t = triples.unbind(dim=1) # (|T|,)
colours = [torch.zeros(n_ent, dtype=torch.long, device=device)] # C⁰ = 0

# pre‑compute 64‑bit mix constants once
mix_r = (
torch.arange(triples[:, 1].max() + 1, device=device, dtype=torch.long)
* 1_146_189_683_093_321_123
)
mix_c = 8_636_673_225_737_527_201

for _ in range(depth):
prev = colours[-1] # (|V|,)

# signatures for outgoing edges
sig_out = mix_r[r] ^ (prev[t] * mix_c)
# signatures for incoming edges
sig_in = mix_r[r] ^ (prev[h] * mix_c) ^ 0x9E3779B97F4A7C15

# bucket signatures back to the source/target nodes
col_out = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, h, sig_out)
col_in = torch.zeros(n_ent, dtype=torch.long, device=device).index_add_(0, t, sig_in)

# combine ↓ and ↑ multiset hashes with current colour
raw = (prev * 3_205_813_371) ^ col_out ^ (col_in << 1)

# re‑map to dense consecutive ids with torch.unique
uniq, new = torch.unique(raw, sorted=True, return_inverse=True)
colours.append(new)

return colours # length = depth+1, each (|V|,) long


# ---------------------------------------------------------------------
# WL‑CREC metric
# ---------------------------------------------------------------------
def wl_crec_gpu(
triples: Tensor, colours: List[Tensor], n_ent: int, n_rel: int, depth: int, device: torch.device
) -> float:
log_n = log(max(n_ent, 2))

# ------- pattern‑diversity C -------------------------------------
C = 0.0
for c in colours: # (|V|,)
# histogram via bincount on GPU
hist = torch.bincount(c).float()
p = hist / n_ent
C += -(p * torch.log(p.clamp_min(1e-30))).sum().item()
C /= (depth + 1) * log_n

# ------- residual entropy H_c ------------------------------------
h, r, t = triples.unbind(dim=1)
sigma, tau = colours[depth][h], colours[depth][t] # (|T|,)

# encode (sigma,tau) pairs into a single 64‑bit key
sig_keys = (sigma.to(torch.int64) << 32) + tau.to(torch.int64)

# total count per signature
sig_unique, sig_inv, sig_counts = torch.unique(
sig_keys, return_inverse=True, return_counts=True, sorted=False
)

    # build 2-D contingency table counts[(sigma,tau), r] by flattening each
    # (signature, relation) pair into one index and counting occurrences;
    # the original scatter_add_ used out-of-range dim-0 indices and a 1-D src
    m = triples.size(0)
    keys_2d = sig_inv * n_rel + r
    rc_dense = torch.bincount(keys_2d, minlength=len(sig_unique) * n_rel).view(
        len(sig_unique), n_rel
    )

# conditional entropy per (sigma,tau)
p_r = rc_dense.float() / sig_counts.unsqueeze(1).float()
inner = -(p_r * torch.log(p_r.clamp_min(1e-30))).sum(dim=1) # (|sigma|,)

Hc = (sig_counts.float() / m * inner).sum().item() / log(max(n_rel, 2))

return C * Hc


# ---------------------------------------------------------------------
# delta‑entropy helpers (still CPU side for clarity; cost is negligible)
# ---------------------------------------------------------------------
# The deterministic delta‑formulas depend on small dictionaries and are evaluated
# on ≤256 candidate edges each iteration – copy them from the original script.
inner_entropy = ... # unchanged
delta_h_cond_remove = ... # unchanged
delta_h_cond_add = ... # unchanged


# ---------------------------------------------------------------------
# Greedy delta‑search driver
# ---------------------------------------------------------------------
def greedy_delta_tune_gpu(
triples: List[Triple],
n_ent: int,
n_rel: int,
depth: int,
lower: float,
upper: float,
device: torch.device,
max_iters: int = 40_000,
sample_size: int = 256,
seed: int = 0,
) -> List[Triple]:
# book‑keeping in Python lists (cheap); heavy math in torch on GPU
rng = random.Random(seed)

# cache torch version of triples for fast metric eval
def triples_to_tensor(ts: List[Triple]) -> Tensor:
return torch.tensor(ts, dtype=torch.long, device=device, requires_grad=False)

for it in range(max_iters):
tri_tensor = triples_to_tensor(triples)
colours = wl_colours_gpu(tri_tensor, n_ent, depth, device)
crec = wl_crec_gpu(tri_tensor, colours, n_ent, n_rel, depth, device)

if lower <= crec <= upper:
logging.info("WL‑CREC %.4f reached after %d edits (|T|=%d)", crec, it, len(triples))
return triples

# ----------------------------------------------------------------
# build signature statistics (CPU, cheap: O(|T|))
# ----------------------------------------------------------------
depth_col = colours[depth].cpu()
sig_cnt: Dict[Tuple[int, int], int] = {}
rel_cnt: Dict[Tuple[int, int, int], int] = {}

for h_idx, (h, r, t) in enumerate(triples):
sigma, tau = int(depth_col[h]), int(depth_col[t])
sig_cnt[(sigma, tau)] = sig_cnt.get((sigma, tau), 0) + 1
rel_cnt[(sigma, tau, r)] = rel_cnt.get((sigma, tau, r), 0) + 1

det_edges, div_edges = [], []
for idx, (h, r, t) in enumerate(triples):
sigma, tau = int(depth_col[h]), int(depth_col[t])
if sum(rel_cnt.get((sigma, tau, rr), 0) > 0 for rr in range(n_rel)) == 1:
det_edges.append(idx)
if sigma != tau:
div_edges.append(idx)

# ----------------------------------------------------------------
# candidate generation + best edit selection
# ----------------------------------------------------------------
total = len(triples)
target_high = crec < lower # need to raise WL‑CREC?
candidates = []

if target_high and det_edges:
rng.shuffle(det_edges)
for idx in det_edges[:sample_size]:
h, r, t = triples[idx]
sig = (int(depth_col[h]), int(depth_col[t]))
delta = delta_h_cond_remove(sig, r, sig_cnt, rel_cnt, total, crec, n_rel)
if delta > 0:
candidates.append(("remove", idx, delta))

elif not target_high and det_edges:
rng.shuffle(det_edges)
for idx in det_edges[:sample_size]:
h, r, t = triples[idx]
sig = (int(depth_col[h]), int(depth_col[t]))
delta = delta_h_cond_add(sig, r, sig_cnt, rel_cnt, total, crec, n_rel)
if delta < 0:
candidates.append(("add", (h, r, t), delta))

# fall‑back heuristics
if not candidates:
if target_high and div_edges:
idx = rng.choice(div_edges)
candidates.append(("remove", idx, 1e-9))
elif not target_high:
idx = rng.choice(det_edges) if det_edges else rng.randrange(total)
h, r, t = triples[idx]
candidates.append(("add", (h, r, t), -1e-9))

        # If no candidate survived (e.g. no deterministic or diverse edges
        # remain), fail loudly instead of crashing inside max()/min().
        if not candidates:
            msg = "No candidate edit found; the graph is too constrained for the target band."
            raise RuntimeError(msg)

        best = (
            max(candidates, key=lambda x: x[2])
            if target_high
            else min(candidates, key=lambda x: x[2])
        )

# apply edit
if best[0] == "remove":
triples.pop(best[1])
else:
triples.append(best[1])

if (it + 1) % 1_000 == 0:
logging.info("[iter %d] WL‑CREC %.4f |T|=%d", it + 1, crec, len(triples))

raise RuntimeError("Max iterations exceeded without hitting WL‑CREC band.")

+ 30
- 0
metrics/ranking.py View File

import torch

from models.base_model import ERModel


def _rank(scores: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return (scores > scores[target]).sum() + 1


def simple_ranking(model: ERModel, triples: torch.Tensor) -> torch.Tensor:
    """Rank the true tail of each triple against all candidate tails."""
    h, r, t = triples[:, 0], triples[:, 1], triples[:, 2]
    E = model.num_entities
    ranks = []

    with torch.no_grad():
        for hi, ri, ti in zip(h, r, t):
            tails = torch.arange(E, device=triples.device)
            # Score (hi, ri, ?) against every candidate tail. Use a fresh name
            # here: the original re-bound `triples`, clobbering the input batch.
            candidates = torch.stack([hi.repeat(E), ri.repeat(E), tails], dim=1)
            scores = model(candidates)
            ranks.append(_rank(scores, ti))

    return torch.tensor(ranks, device=triples.device, dtype=torch.float32)


def mrr(r: torch.Tensor) -> torch.Tensor:
return (1 / r).mean()


def hits_at_k(r: torch.Tensor, k: int = 10) -> torch.Tensor:
return (r <= k).float().mean()

+ 237
- 0
metrics/wlcrec.py View File

from __future__ import annotations

import math
from collections import Counter, defaultdict
from typing import TYPE_CHECKING, Dict, List, Tuple

import torch

from .base_metric import BaseMetric

if TYPE_CHECKING:
from data.kg_dataset import KGDataset


def _entropy_from_counter(counter: Counter[int], n: int) -> float:
"""Return Shannon entropy (nats) of a colour distribution of length *n*."""

ent = 0.0
for cnt in counter.values():
if cnt:
p = cnt / n
ent -= p * math.log(p)
return ent


def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]:
"""Return (C_ratio, C_NWLEC) ∈ [0,1]² from per-layer entropies."""

if n <= 1:
return 0.0, 0.0
H_plus_1 = len(entropies)
log_n = math.log(n)
c_ratio = sum(ent / log_n for ent in entropies) / H_plus_1
c_nwlec = sum((math.exp(ent) - 1) / (n - 1) for ent in entropies) / H_plus_1
return c_ratio, c_nwlec


def _relation_families(
triples: torch.Tensor, num_relations: int, *, thresh: float = 0.9
) -> torch.Tensor:
"""Return a LongTensor mapping each relation id → family id.

Two relations *r, r_inv* are put into the same family when **at least
`thresh` fraction** of the edges labelled *r* have a **single** reverse edge
labelled *r_inv*.
"""
assert 0.0 < thresh <= 1.0

# Build map (h,t) -> list of relations on that edge direction.
edge2rels: Dict[Tuple[int, int], List[int]] = defaultdict(list)
for h, r, t in triples.tolist():
edge2rels[(h, t)].append(int(r))

# Count reciprocal co‑occurrences (r, r_rev).
pair_counts: Dict[Tuple[int, int], int] = defaultdict(int)
rel_totals: List[int] = [0] * num_relations
for (h, t), rels in edge2rels.items():
rev_rels = edge2rels.get((t, h))
if not rev_rels:
continue
for r in rels:
rel_totals[r] += 1
for r_rev in rev_rels:
pair_counts[(r, r_rev)] += 1

# Proposed inverse mapping r -> r_inv.
proposed: List[int | None] = [None] * num_relations
for (r, r_rev), cnt in pair_counts.items():
if cnt == 0:
continue
if cnt / rel_totals[r] >= thresh:
# majority of r's edges are reversed by r_rev
proposed[r] = r_rev

# Build family ids (transitive closure is overkill; data are simple).
family = list(range(num_relations))
for r, rinv in enumerate(proposed):
if rinv is not None:
root = min(family[r], family[rinv])
family[r] = family[rinv] = root
return torch.tensor(family, dtype=torch.long)


def _conditional_relation_entropy(
colours: List[int],
triples: torch.Tensor,
family_map: torch.Tensor,
) -> float:
counts: Dict[Tuple[int, int], Dict[int, int]] = defaultdict(lambda: defaultdict(int))
m = int(triples.size(0))
fam = family_map # alias for speed

for h, r, t in triples.tolist():
# key = (colours[h], colours[t]) # ordered key (direction‑aware)
key = tuple(sorted((colours[h], colours[t]))) # unordered key (direction‑agnostic)
counts[key][int(fam[r])] += 1 # ▼ use family id

h_cond = 0.0
for rel_counts in counts.values():
total_s = sum(rel_counts.values())
P_s = total_s / m
inv_total = 1.0 / total_s
for cnt in rel_counts.values():
p_rs = cnt * inv_total
h_cond -= P_s * p_rs * math.log(p_rs)
return h_cond


class WLCREC(BaseMetric):
"""Class to compute the Weisfeiler-Lehman Entropy Complexity (WLEC) for a knowledge graph."""

def __init__(self, dataset: KGDataset) -> None:
super().__init__(dataset)

def compute(
self,
H: int = 3,
cond_h: int = 1,
inv_thresh: float = 0.9,
return_full: bool = False,
progress: bool = True,
    ) -> Tuple[float, ...]:
        """Compute WL-entropy scores and the composite difficulty metric.

        Parameters
        ----------
        H
            WL refinement depth (*≥0*). Entropy is recorded for **H+1** layers.
        cond_h
            Index of the WL layer whose colouring conditions the residual
            relation entropy.
        inv_thresh
            Threshold above which a relation and its inverse are merged into
            one family (see :func:`_relation_families`).
        return_full
            Also return the list of per-layer entropies.
        progress
            Print a textual progress bar.

        Returns
        -------
        avg_entropy, C_ratio, C_NWLEC, H_cond, D_ratio, D_NWLEC
            *Six* scalars (floats). If ``return_full=True`` the list of
            per-layer entropies is appended as a seventh element.
        """

# compute relation family map once
family_map = _relation_families(
self.dataset.triples, self.dataset.num_relations, thresh=inv_thresh
)

# ─── WL iterations ────────────────────────────────────────────────────

n = self.dataset.num_entities
colour: List[int] = [1] * n # colour 1 for everyone
entropies: List[float] = []
colours_per_h: List[List[int]] = []

def _print_prog(i: int) -> None:
if progress:
print(f"\rWL iteration {i}/{H}", end="", flush=True)

for h in range(H + 1):
_print_prog(h)

colours_per_h.append(colour.copy())

# entropy of current colouring
freq = Counter(colour)
entropies.append(_entropy_from_counter(freq, n))

if h == H:
break

# refine colours
bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {}
next_colour: List[int] = [0] * n

for v in range(n):
T: List[Tuple[int, int, int]] = []
for r, t in self.out_adj[v]:
T.append((r, 0, colour[t])) # outgoing
for r, h_ in self.in_adj[v]:
T.append((r, 1, colour[h_])) # incoming
T.sort()
key = (colour[v], tuple(T))
if key not in bucket:
bucket[key] = len(bucket) + 1
next_colour[v] = bucket[key]

colour = next_colour

_print_prog(H)
if progress:
print()

# ─── Normalised diversity ────────────────────────────────────────────
avg_entropy = sum(entropies) / len(entropies)
c_ratio, c_nwlec = _normalise_scores(entropies, n)

# ─── Residual relation entropy ───────────────────────────────────────
h_cond = _conditional_relation_entropy(
colours_per_h[cond_h], self.dataset.triples, family_map
)

# ─── Composite difficulty - Eq. (1) ──────────────────────────────────
d_ratio = c_ratio * h_cond
d_nwlec = c_nwlec * h_cond

result = (avg_entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec)
if return_full:
result += (entropies,)
return result

# --------------------------------------------------------------------
# Weisfeiler–Lehman colours (unchanged)
# --------------------------------------------------------------------
def wl_colours(self, depth: int) -> List[List[int]]:
triples = self.dataset.triples.tolist()
out_adj, in_adj = defaultdict(list), defaultdict(list)
for h, r, t in triples:
out_adj[h].append((r, t))
in_adj[t].append((r, h))

n_ent = self.dataset.num_entities
colours = [[0] * n_ent] # round-0
for h in range(1, depth + 1):
prev, nxt, sig2c = colours[-1], [0] * n_ent, {}
fresh = 0
for v in range(n_ent):
neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [
("↑", r, prev[u]) for r, u in in_adj.get(v, [])
]
neigh.sort()
sig = (prev[v], tuple(neigh))
if sig not in sig2c:
sig2c[sig] = fresh
fresh += 1
nxt[v] = sig2c[sig]
colours.append(nxt)
return colours
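
# Example usage (sketch): the six WL-CREC scores on a tiny synthetic KGDataset,
# constructed the same way metrics/crec_radius_sample.py builds its samples.
if __name__ == "__main__":
    from data.kg_dataset import KGDataset

    toy = KGDataset(
        triples_factory=None,
        triples=torch.tensor([[0, 0, 1], [1, 1, 2], [2, 0, 0]], dtype=torch.long),
        num_entities=3,
        num_relations=2,
        split="train",
    )
    avg_h, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec = WLCREC(toy).compute(
        H=2, progress=False
    )
    print(f"C_ratio={c_ratio:.4f}  H_cond={h_cond:.4f}  D_ratio={d_ratio:.4f}")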

+ 135
- 0
metrics/wlec.py View File

from __future__ import annotations

import math
from collections import Counter
from typing import TYPE_CHECKING, Dict, List, Tuple

from .base_metric import BaseMetric

if TYPE_CHECKING:
from data.kg_dataset import KGDataset


def _entropy_from_counter(counter: Counter[int], n: int) -> float:
"""Shannon entropy ``H = -Σ pᵢ log pᵢ`` in *nats* for the given colour counts."""
ent = 0.0
for count in counter.values():
if count: # avoid log(0)
p = count / n
ent -= p * math.log(p)
return ent


def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]:
"""Return (C_ratio, C_NWLEC) given *per-iteration* entropies."""

if n <= 1:
# Degenerate graph.
return 0.0, 0.0

H_plus_1 = len(entropies)
# Eq. (1): simple ratio normalisation.
c_ratio = sum(ent / math.log(n) for ent in entropies) / H_plus_1

# Eq. (2): effective-colour NWLEC.
k_terms = [(math.exp(ent) - 1) / (n - 1) for ent in entropies]
c_nwlec = sum(k_terms) / H_plus_1
return c_ratio, c_nwlec


class WLEC(BaseMetric):
"""Class to compute the Weisfeiler-Lehman Entropy Complexity (WLEC) for a knowledge graph."""

def __init__(self, dataset: KGDataset) -> None:
super().__init__(dataset)

def compute(
self,
H: int = 3,
*,
return_full: bool = False,
progress: bool = True,
    ) -> Tuple[float, ...]:
        """Compute the *average* Weisfeiler-Lehman Entropy Complexity (WLEC).

        Parameters
        ----------
        H:
            Number of refinement iterations **after** the initial colouring (so the
            colours are updated **H** times and the entropy is measured **H+1** times).
        return_full:
            If *True*, additionally return the list of entropies for each depth.
        progress:
            Whether to print a mini progress bar (no external dependency).

        Returns
        -------
        ``(average_entropy, C_ratio, C_NWLEC)``, or
        ``(average_entropy, entropies, C_ratio, C_NWLEC)`` when ``return_full=True``.
        """
n = int(self.num_entities)
if n == 0:
msg = "Dataset appears to contain zero entities."
raise ValueError(msg)

# Colour assignment - we keep it as a *list* of ints for cheap hashing.
# Start with colour 1 for every entity (index 0 unused).
colour: List[int] = [1] * n

entropies: List[float] = []

# Optional poor-man's progress bar.
def _print_prog(i: int) -> None:
if progress:
print(f"\rIteration {i}/{H}", end="", flush=True)

for h in range(H + 1):
_print_prog(h)

# ------------------------------------------------------------------
# Step 1: measure entropy of current colouring ----------------------
# ------------------------------------------------------------------
freq = Counter(colour)
ent = _entropy_from_counter(freq, n)
entropies.append(ent)

if h == H:
break # We have reached the requested depth - stop here.

# ------------------------------------------------------------------
# Step 2: create refined colours -----------------------------------
# ------------------------------------------------------------------
bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {}
next_colour: List[int] = [0] * n

for v in range(n):
# Build the multiset T containing outgoing and incoming features.
T: List[Tuple[int, int, int]] = []
for r, t in self.out_adj[v]:
T.append((r, 0, colour[t])) # 0 = outgoing (\u2193)
for r, h_ in self.in_adj[v]:
T.append((r, 1, colour[h_])) # 1 = incoming (\u2191)
T.sort() # canonical ordering

key = (colour[v], tuple(T))
if key not in bucket:
bucket[key] = len(bucket) + 1 # new colour ID (start at 1)
next_colour[v] = bucket[key]

colour = next_colour # move to next iteration

_print_prog(H)
if progress:
print() # newline after progress bar

average_entropy = sum(entropies) / len(entropies)

c_ratio, c_nwlec = _normalise_scores(entropies, n)

if return_full:
return average_entropy, entropies, c_ratio, c_nwlec

return average_entropy, c_ratio, c_nwlec
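
# Example usage (sketch): WLEC returns (average_entropy, C_ratio, C_NWLEC) by
# default; the synthetic dataset mirrors the example in wlcrec.py.
if __name__ == "__main__":
    import torch

    from data.kg_dataset import KGDataset

    toy = KGDataset(
        triples_factory=None,
        triples=torch.tensor([[0, 0, 1], [1, 1, 2], [2, 0, 0]], dtype=torch.long),
        num_entities=3,
        num_relations=2,
        split="train",
    )
    avg_h, c_ratio, c_nwlec = WLEC(toy).compute(H=2, progress=False)
    print(f"avg={avg_h:.4f}  C_ratio={c_ratio:.4f}  C_NWLEC={c_nwlec:.4f}")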

+ 0
- 0
models/__init__.py View File


+ 127
- 0
models/base_model.py View File

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch
from torch import nn
from torch.nn.functional import log_softmax, softmax

if TYPE_CHECKING:
from pykeen.typing import FloatTensor, InductiveMode, LongTensor, MappedTriples, Target


class ERModel(nn.Module, ABC):
"""Base class for knowledge graph models."""

def __init__(self, num_entities: int, num_relations: int, dim: int) -> None:
super().__init__()
self.num_entities = num_entities
self.num_relations = num_relations
self.dim = dim

self.inner_model = None

@property
def device(self) -> torch.device:
"""Return the device of the model."""
return next(self.inner_model.parameters()).device

@abstractmethod
def reset_parameters(self, *args, **kwargs) -> None:
"""Reset the parameters of the model."""

def predict(
self,
hrt_batch: MappedTriples,
target: Target,
full_batch: bool = True,
ids: LongTensor | None = None,
**kwargs,
) -> FloatTensor:
if not self.inner_model:
msg = (
"Inner model is not set. Please initialize the inner model before calling predict."
)
raise ValueError(
msg,
)

return self.inner_model.predict(
hrt_batch=hrt_batch,
target=target,
full_batch=full_batch,
ids=ids,
**kwargs,
)

    def forward(
        self,
        triples: MappedTriples,
        slice_size: int | None = None,
        slice_dim: int = 0,
        *,
        mode: InductiveMode | None = None,
    ) -> FloatTensor:
        """Score a batch of (h, r, t) triples with the wrapped PyKEEN model."""
h_indices = triples[:, 0]
r_indices = triples[:, 1]
t_indices = triples[:, 2]

if not self.inner_model:
msg = (
"Inner model is not set. Please initialize the inner model before calling forward."
)
raise ValueError(
msg,
)

return self.inner_model.forward(
h_indices=h_indices,
r_indices=r_indices,
t_indices=t_indices,
slice_size=slice_size,
slice_dim=slice_dim,
mode=mode,
)

@torch.inference_mode()
def tail_distribution( # (0) TAIL-given-(head, relation) → pθ(t | h,r)
self,
hr_batch: LongTensor, # (B, 2) : (h,r)
*,
slice_size: int | None = None,
mode: InductiveMode | None = None,
log: bool = False,
) -> FloatTensor: # (B, |E|)
scores = self.inner_model.score_t(hr_batch=hr_batch, slice_size=slice_size, mode=mode)
return log_softmax(scores, -1) if log else softmax(scores, -1)

# ----------------------------------------------------------------------- #
# (1) HEAD-given-(relation, tail) → pθ(h | r,t) #
# ----------------------------------------------------------------------- #
@torch.inference_mode()
def head_distribution(
self,
rt_batch: LongTensor, # (B, 2) : (r,t)
*,
slice_size: int | None = None,
mode: InductiveMode | None = None,
log: bool = False,
) -> FloatTensor: # (B, |E|)
scores = self.inner_model.score_h(rt_batch=rt_batch, slice_size=slice_size, mode=mode)
return log_softmax(scores, -1) if log else softmax(scores, -1)

# ----------------------------------------------------------------------- #
# (2) RELATION-given-(head, tail) → pθ(r | h,t) #
# ----------------------------------------------------------------------- #
@torch.inference_mode()
def relation_distribution(
self,
ht_batch: LongTensor, # (B, 2) : (h,t)
*,
slice_size: int | None = None,
mode: InductiveMode | None = None,
log: bool = False,
) -> FloatTensor: # (B, |R|)
scores = self.inner_model.score_r(ht_batch=ht_batch, slice_size=slice_size, mode=mode)
return log_softmax(scores, -1) if log else softmax(scores, -1)
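
# Example usage (sketch): querying the conditional distributions of a trained
# subclass (e.g. models.translation.trans_e.TransE). Shapes follow the comments
# above; `model` is assumed to be an already-trained instance.
#
#     hr_batch = torch.tensor([[0, 1], [2, 0]])            # (B, 2) head/relation ids
#     p_tail = model.tail_distribution(hr_batch)           # (B, |E|), rows sum to 1
#     log_p = model.tail_distribution(hr_batch, log=True)  # log-probabilities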

+ 0
- 0
models/translation/__init__.py View File


BIN
models/translation/__pycache__/__init__.cpython-310.pyc View File


BIN
models/translation/__pycache__/trans_e.cpython-310.pyc View File


+ 70
- 0
models/translation/trans_e.py View File

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from pykeen.models import TransE as PyKEENTransE
from torch import nn

from models.base_model import ERModel

if TYPE_CHECKING:
from pykeen.typing import MappedTriples


class TransE(ERModel):
    """
    TransE model for knowledge graph embedding.

    score(h, r, t) = -||h + r - t||  (PyKEEN convention: higher is better)
    """

def __init__(
self,
num_entities: int,
num_relations: int,
triples_factory: MappedTriples,
dim: int = 200,
sf_norm: int = 1,
p_norm: bool = True,
margin: float | None = None,
epsilon: float | None = None,
device: torch.device = torch.device("cpu"),
) -> None:
super().__init__(num_entities, num_relations, dim)
self.dim = dim
self.sf_norm = sf_norm
self.p_norm = p_norm
self.margin = margin
self.epsilon = epsilon

        self.inner_model = PyKEENTransE(
            triples_factory=triples_factory,
            embedding_dim=dim,
            scoring_fct_norm=sf_norm,
            power_norm=p_norm,
            random_seed=42,
        ).to(device)

        # Alias the inner model's embedding tables so reset_parameters() below
        # can reach them. NOTE: this assumes PyKEEN's Embedding representation
        # exposes an nn.Embedding at `_embeddings`; adjust for other versions.
        self.ent_embeddings = self.inner_model.entity_representations[0]._embeddings
        self.rel_embeddings = self.inner_model.relation_representations[0]._embeddings

def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None:
# Parameter initialization
if margin is None or epsilon is None:
# If no margin/epsilon are provided, use Xavier uniform initialization
nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
else:
# Otherwise, initialize uniformly in [ -(margin+epsilon)/dim , +(margin+epsilon)/dim ]
self.embedding_range = nn.Parameter(
torch.Tensor([(margin + epsilon) / self.dim]),
requires_grad=False,
)
nn.init.uniform_(
tensor=self.ent_embeddings.weight.data,
a=-self.embedding_range.item(),
b=self.embedding_range.item(),
)
nn.init.uniform_(
tensor=self.rel_embeddings.weight.data,
a=-self.embedding_range.item(),
b=self.embedding_range.item(),
)
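
# Example usage (sketch): wiring the wrapper to a real PyKEEN triples factory.
# `Nations` is a tiny dataset bundled with PyKEEN, so this runs without downloads.
if __name__ == "__main__":
    from pykeen.datasets import Nations

    tf = Nations().training
    model = TransE(
        num_entities=tf.num_entities,
        num_relations=tf.num_relations,
        triples_factory=tf,
        dim=50,
    )
    print(model.device, model.num_entities, model.num_relations)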

+ 76
- 0
models/translation/trans_h.py View File

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from pykeen.losses import MarginRankingLoss
from pykeen.models import TransH as PyKEENTransH
from torch import nn

from models.base_model import ERModel

if TYPE_CHECKING:
from pykeen.typing import MappedTriples


class TransH(ERModel):
"""
TransH model for knowledge graph embedding.
"""

def __init__(
self,
num_entities: int,
num_relations: int,
triples_factory: MappedTriples,
dim: int = 200,
p_norm: bool = True,
margin: float | None = None,
epsilon: float | None = None,
device: torch.device = torch.device("cpu"),
**kwargs,
) -> None:
super().__init__(num_entities, num_relations, dim)
self.dim = dim
self.p_norm = p_norm
self.margin = margin
self.epsilon = epsilon

        self.inner_model = PyKEENTransH(
            triples_factory=triples_factory,
            embedding_dim=dim,
            scoring_fct_norm=1,
            loss=self.loss,
            power_norm=p_norm,
            entity_initializer="xavier_uniform_",  # good default
            relation_initializer="xavier_uniform_",
            random_seed=42,
        ).to(device)

        # Alias the inner embeddings for reset_parameters(); assumes PyKEEN's
        # Embedding representation exposes `_embeddings` (see trans_e.py).
        self.ent_embeddings = self.inner_model.entity_representations[0]._embeddings
        self.rel_embeddings = self.inner_model.relation_representations[0]._embeddings

    @property
    def loss(self) -> MarginRankingLoss:
        # NOTE: requires a numeric `margin`; construct the model with margin
        # set (PyKEEN's MarginRankingLoss cannot take None).
        return MarginRankingLoss(margin=self.margin, reduction="mean")

def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None:
# Parameter initialization
if margin is None or epsilon is None:
# If no margin/epsilon are provided, use Xavier uniform initialization
nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
else:
# Otherwise, initialize uniformly in [ -(margin+epsilon)/dim , +(margin+epsilon)/dim ]
self.embedding_range = nn.Parameter(
torch.Tensor([(margin + epsilon) / self.dim]),
requires_grad=False,
)
nn.init.uniform_(
tensor=self.ent_embeddings.weight.data,
a=-self.embedding_range.item(),
b=self.embedding_range.item(),
)
nn.init.uniform_(
tensor=self.rel_embeddings.weight.data,
a=-self.embedding_range.item(),
b=self.embedding_range.item(),
)

+ 79
- 0
models/translation/trans_r.py View File

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from pykeen.losses import MarginRankingLoss
from pykeen.models import TransR as PyKEENTransR
from torch import nn

from models.base_model import ERModel

if TYPE_CHECKING:
from pykeen.typing import MappedTriples


class TransR(ERModel):
"""
TransR model for knowledge graph embedding.
"""

def __init__(
self,
num_entities: int,
num_relations: int,
triples_factory: MappedTriples,
entity_dim: int = 200,
relation_dim: int = 30,
p_norm: bool = True,
p_norm_value: int = 2,
margin: float | None = None,
epsilon: float | None = None,
device: torch.device = torch.device("cpu"),
**kwargs,
) -> None:
super().__init__(num_entities, num_relations, entity_dim)
self.entity_dim = entity_dim
self.relation_dim = relation_dim
self.p_norm = p_norm
self.margin = margin
self.epsilon = epsilon

        self.inner_model = PyKEENTransR(
            triples_factory=triples_factory,
            embedding_dim=entity_dim,
            relation_dim=relation_dim,
            scoring_fct_norm=1,
            loss=self.loss,
            power_norm=p_norm,
            entity_initializer="xavier_uniform_",
            relation_initializer="xavier_uniform_",
            random_seed=42,
        ).to(device)

        # Alias the inner embeddings for reset_parameters(); assumes PyKEEN's
        # Embedding representation exposes `_embeddings` (see trans_e.py).
        self.ent_embeddings = self.inner_model.entity_representations[0]._embeddings
        self.rel_embeddings = self.inner_model.relation_representations[0]._embeddings

    @property
    def loss(self) -> MarginRankingLoss:
        # NOTE: requires a numeric `margin`; construct the model with margin
        # set (PyKEEN's MarginRankingLoss cannot take None).
        return MarginRankingLoss(margin=self.margin, reduction="mean")

def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None:
# Parameter initialization
if margin is None or epsilon is None:
# If no margin/epsilon are provided, use Xavier uniform initialization
nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
else:
# Otherwise, initialize uniformly in [ -(margin+epsilon)/dim , +(margin+epsilon)/dim ]
self.embedding_range = nn.Parameter(
torch.Tensor([(margin + epsilon) / self.dim]),
requires_grad=False,
)
nn.init.uniform_(
tensor=self.ent_embeddings.weight.data,
a=-self.embedding_range.item(),
b=self.embedding_range.item(),
)
nn.init.uniform_(
tensor=self.rel_embeddings.weight.data,
a=-self.embedding_range.item(),
b=self.embedding_range.item(),
)

+ 82
- 0
pyproject.toml View File

[tool.ruff]
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".ipynb_checkpoints",
".mypy_cache",
".nox",
".pants.d",
".pyenv",
".pytest_cache",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".vscode",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]
respect-gitignore = true

# Black's default line length is 88; this project allows 100.
line-length = 100
indent-width = 4

[tool.ruff.lint]
# Enable every rule, then opt out of the noisy ones via `ignore` below.
select = ["ALL"]
ignore = ["ANN002", "ANN003", "RET504", "ERA001", "E741", "N812", "D", "PLR0913", "T201", "FBT001", "FBT002", "PLW0602", "PLW0603", "PTH", "C901", "PLR2004", "UP006", "UP035", "G004", "N803", "N806", "ARG002"]

# Allow fix for all enabled rules (when `--fix` is provided).
fixable = ["ALL"]
unfixable = []

# Allow unused variables when underscore-prefixed.
#dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E", "F", "I", "N"]

[tool.ruff.format]
# Like Black, use double quotes for strings.
quote-style = "double"

# Like Black, indent with spaces, rather than tabs.
indent-style = "space"

# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false

# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"

# Enable auto-formatting of code examples in docstrings. Markdown,
# reStructuredText code/literal blocks and doctests are all supported.
#
# This is currently disabled by default, but it is planned for this
# to be opt-out in the future.
docstring-code-format = false

# Set the line length limit used when formatting code snippets in
# docstrings.
#
# This only has an effect when the `docstring-code-format` setting is
# enabled.
docstring-code-line-length = "dynamic"

[tool.pyright]
typeCheckingMode = "off"

+ 28
- 0
setup/install.sh View File

#!/usr/bin/env bash
# install.sh – minimal Conda bootstrap with ~/.bashrc check
set -e

ENV=kg-env       # env name to create (`-n` below overrides the YAML's `name:` field)
YAML=kg_env.yaml # assumed to be in the current dir

# ── 1. try to load an existing conda from ~/.bashrc ──────────────────────
[ -f "$HOME/.bashrc" ] && source "$HOME/.bashrc" || true

# ── 2. ensure conda exists ───────────────────────────────────────────────
if ! command -v conda &>/dev/null; then
echo "[install.sh] Conda not found → installing Miniconda."
curl -fsSL https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o miniconda.sh
bash miniconda.sh -b -p "$HOME/miniconda3"
rm miniconda.sh
source "$HOME/miniconda3/etc/profile.d/conda.sh"
conda init bash >/dev/null # so future shells have it
else
# conda exists; load its helper for this shell
source "$(conda info --base)/etc/profile.d/conda.sh"
fi

# ── 3. create or update the project env ──────────────────────────────────
conda env create -n "$ENV" -f "$YAML" \
|| conda env update -n "$ENV" -f "$YAML" --prune

echo "✔ Done. Activate with: conda activate $ENV"

+ 163
- 0
setup/kg_env.yaml View File

# environment.yml ── KG-Embeddings project
name: kg
channels:
- pytorch
- defaults
- nvidia
- conda-forge
- https://repo.anaconda.com/pkgs/main
- https://repo.anaconda.com/pkgs/r
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- aom=3.6.1=h59595ed_0
- blas=1.0=mkl
- brotli-python=1.1.0=py311hfdbb021_2
- bzip2=1.0.8=h4bc722e_7
- ca-certificates=2024.12.14=hbcca054_0
- certifi=2024.12.14=pyhd8ed1ab_0
- cffi=1.17.1=py311hf29c0ef_0
- charset-normalizer=3.4.1=pyhd8ed1ab_0
- cpython=3.11.11=py311hd8ed1ab_1
- cuda-cudart=12.4.127=0
- cuda-cupti=12.4.127=0
- cuda-libraries=12.4.1=0
- cuda-nvrtc=12.4.127=0
- cuda-nvtx=12.4.127=0
- cuda-opencl=12.6.77=0
- cuda-runtime=12.4.1=0
- cuda-version=12.6=3
- ffmpeg=4.4.2=gpl_hdf48244_113
- filelock=3.16.1=pyhd8ed1ab_1
- font-ttf-dejavu-sans-mono=2.37=hab24e00_0
- font-ttf-inconsolata=3.000=h77eed37_0
- font-ttf-source-code-pro=2.038=h77eed37_0
- font-ttf-ubuntu=0.83=h77eed37_3
- fontconfig=2.15.0=h7e30c49_1
- fonts-conda-ecosystem=1=0
- fonts-conda-forge=1=0
- freetype=2.12.1=h267a509_2
- gettext=0.22.5=he02047a_3
- gettext-tools=0.22.5=he02047a_3
- giflib=5.2.2=hd590300_0
- gmp=6.3.0=hac33072_2
- gmpy2=2.1.5=py311h0f6cedb_3
- gnutls=3.7.9=hb077bed_0
- h2=4.1.0=pyhd8ed1ab_1
- hpack=4.0.0=pyhd8ed1ab_1
- hyperframe=6.0.1=pyhd8ed1ab_1
- idna=3.10=pyhd8ed1ab_1
- intel-openmp=2022.0.1=h06a4308_3633
- jinja2=3.1.5=pyhd8ed1ab_0
- lame=3.100=h166bdaf_1003
- lcms2=2.16=hb7c19ff_0
- ld_impl_linux-64=2.43=h712a8e2_2
- lerc=4.0.0=h27087fc_0
- libasprintf=0.22.5=he8f35ee_3
- libasprintf-devel=0.22.5=he8f35ee_3
- libblas=3.9.0=16_linux64_mkl
- libcblas=3.9.0=16_linux64_mkl
- libcublas=12.4.5.8=0
- libcufft=11.2.1.3=0
- libcufile=1.11.1.6=0
- libcurand=10.3.7.77=0
- libcusolver=11.6.1.9=0
- libcusparse=12.3.1.170=0
- libdeflate=1.23=h4ddbbb0_0
- libdrm=2.4.124=hb9d3cd8_0
- libegl=1.7.0=ha4b6fd6_2
- libexpat=2.6.4=h5888daf_0
- libffi=3.4.2=h7f98852_5
- libgcc=14.2.0=h77fa898_1
- libgcc-ng=14.2.0=h69a702a_1
- libgettextpo=0.22.5=he02047a_3
- libgettextpo-devel=0.22.5=he02047a_3
- libgl=1.7.0=ha4b6fd6_2
- libglvnd=1.7.0=ha4b6fd6_2
- libglx=1.7.0=ha4b6fd6_2
- libgomp=14.2.0=h77fa898_1
- libiconv=1.17=hd590300_2
- libidn2=2.3.7=hd590300_0
- libjpeg-turbo=3.0.0=hd590300_1
- liblapack=3.9.0=16_linux64_mkl
- liblzma=5.6.3=hb9d3cd8_1
- libnpp=12.2.5.30=0
- libnsl=2.0.1=hd590300_0
- libnvfatbin=12.6.77=0
- libnvjitlink=12.4.127=0
- libnvjpeg=12.3.1.117=0
- libpciaccess=0.18=hd590300_0
- libpng=1.6.45=h943b412_0
- libsqlite=3.47.2=hee588c1_0
- libstdcxx=14.2.0=hc0a3c3a_1
- libstdcxx-ng=14.2.0=h4852527_1
- libtasn1=4.19.0=h166bdaf_0
- libtiff=4.7.0=hd9ff511_3
- libunistring=0.9.10=h7f98852_0
- libuuid=2.38.1=h0b41bf4_0
- libva=2.22.0=h8a09558_1
- libvpx=1.13.1=h59595ed_0
- libwebp=1.5.0=hae8dbeb_0
- libwebp-base=1.5.0=h851e524_0
- libxcb=1.17.0=h8a09558_0
- libxcrypt=4.4.36=hd590300_1
- libxml2=2.13.5=h0d44e9d_1
- libzlib=1.3.1=hb9d3cd8_2
- llvm-openmp=15.0.7=h0cdce71_0
- markupsafe=3.0.2=py311h2dc5d0c_1
- mkl=2022.1.0=hc2b9512_224
- mpc=1.3.1=h24ddda3_1
- mpfr=4.2.1=h90cbb55_3
- mpmath=1.3.0=pyhd8ed1ab_1
- ncurses=6.5=h2d0b736_2
- nettle=3.9.1=h7ab15ed_0
- networkx=3.4.2=pyh267e887_2
- numpy=2.2.1=py311hf916aec_0
- openh264=2.3.1=hcb278e6_2
- openjpeg=2.5.3=h5fbd93e_0
- openssl=3.4.0=h7b32b05_1
- p11-kit=0.24.1=hc5aa10d_0
- pillow=11.1.0=py311h1322bbf_0
- pip=24.3.1=pyh8b19718_2
- pthread-stubs=0.4=hb9d3cd8_1002
- pycparser=2.22=pyh29332c3_1
- pysocks=1.7.1=pyha55dd90_7
- python=3.11.11=h9e4cc4f_1_cpython
- python_abi=3.11=5_cp311
- pytorch=2.5.1=py3.11_cuda12.4_cudnn9.1.0_0
- pytorch-cuda=12.4=hc786d27_7
- pytorch-mutex=1.0=cuda
- pyyaml=6.0.2=py311h9ecbd09_1
- readline=8.2=h8228510_1
- requests=2.32.3=pyhd8ed1ab_1
- setuptools=75.8.0=pyhff2d567_0
- svt-av1=1.4.1=hcb278e6_0
- sympy=1.13.3=pyh2585a3b_105
- tk=8.6.13=noxft_h4845f30_101
- torchaudio=2.5.1=py311_cu124
- torchtriton=3.1.0=py311
- torchvision=0.20.1=py311_cu124
- typing_extensions=4.12.2=pyha770c72_1
- tzdata=2024b=hc8b5060_0
- urllib3=2.3.0=pyhd8ed1ab_0
- wayland=1.23.1=h3e06ad9_0
- wayland-protocols=1.37=hd8ed1ab_0
- wheel=0.45.1=pyhd8ed1ab_1
- x264=1!164.3095=h166bdaf_2
- x265=3.5=h924138e_3
- xorg-libx11=1.8.10=h4f16b4b_1
- xorg-libxau=1.0.12=hb9d3cd8_0
- xorg-libxdmcp=1.1.5=hb9d3cd8_0
- xorg-libxext=1.3.6=hb9d3cd8_0
- xorg-libxfixes=6.0.1=hb9d3cd8_0
- yaml=0.2.5=h7f98852_2
- zstandard=0.23.0=py311hbc35293_1
- zstd=1.5.6=ha6fb4c9_0
- pip:
- torchmetrics>=1.4 # nicer MRR/Hits@K utilities
- lovely-tensors>=0.1.0 # tensor utilities
- lovely-numpy>=0.1.0 # numpy utilities
- einops==0.8.1
- pykeen==1.11.1
- hydra-core==1.3.2
- tensorboard==2.19.0

+ 6
- 0
tools/__init__.py View File

from .pretty_logger import get_pretty_logger
from .sampling import DeterministicRandomSampler
from .utils import set_seed
from .tb_handler import TensorBoardHandler
from .params import CommonParams, TrainingParams
from .checkpoint_manager import CheckpointManager

+ 221
- 0
tools/checkpoint_manager.py View File

from __future__ import annotations

import re
from pathlib import Path


class CheckpointManager:
"""
A checkpoint manager that handles checkpoint directory creation and path resolution.

Features:
- Creates sequentially numbered checkpoint directories
- Supports two loading modes: by components or by full path
"""

def __init__(
self,
root_directory: str | Path,
run_name: str,
load_only: bool = False,
) -> None:
"""
Initialize the checkpoint manager.

Args:
root_directory: Root directory for checkpoints
run_name: Name of the run (used as suffix for checkpoint directories)
"""
self.root_directory = Path(root_directory)
self.run_name = run_name
self.checkpoint_directory = None

if not load_only:
self.root_directory.mkdir(parents=True, exist_ok=True)
self.checkpoint_directory = self._create_checkpoint_directory()
else:
self.checkpoint_directory = ""

def _find_existing_directories(self) -> list[int]:
"""
Find all existing directories with pattern xxx_<run_name> where xxx is a 3-digit number.

Returns:
List of existing sequence numbers
"""
pattern = re.compile(rf"^(\d{{3}})_{re.escape(self.run_name)}$")
existing_numbers = []

if self.root_directory.exists():
for item in self.root_directory.iterdir():
if item.is_dir():
match = pattern.match(item.name)
if match:
existing_numbers.append(int(match.group(1)))

return sorted(existing_numbers)

def _create_checkpoint_directory(self) -> Path:
"""
Create a new checkpoint directory with the next sequential number.

Returns:
Path to the created checkpoint directory
"""
existing_numbers = self._find_existing_directories()

# Determine the next number
next_number = max(existing_numbers) + 1 if existing_numbers else 1

# Create directory name with 3-digit zero-padded number
dir_name = f"{next_number:03d}_{self.run_name}"
checkpoint_dir = self.root_directory / dir_name
self.run_name = dir_name

# Create the directory
checkpoint_dir.mkdir(parents=True, exist_ok=True)

return checkpoint_dir

def get_checkpoint_directory(self) -> Path:
"""
Get the current checkpoint directory.

Returns:
Path to the checkpoint directory
"""
return self.checkpoint_directory

def get_model_fpath(self, model_path: str) -> Path:
"""
Get the full path to a model checkpoint file.

Args:
model_path: Either a tuple (model_id, iteration) or a string path

Returns:
Full path to the model checkpoint file
"""

        import ast

        try:
            # Parse strings like "(3, 1000)" safely, without executing code
            # (the original used eval(), which runs arbitrary input).
            parsed = ast.literal_eval(model_path)
            if isinstance(parsed, tuple):
                model_id, iteration = parsed
                checkpoint_directory = f"{self.root_directory}/{model_id:03d}_{self.run_name}"
                filename = f"model-{iteration}.pt"
                return Path(f"{checkpoint_directory}/{filename}")
        except (SyntaxError, ValueError):
            pass

if isinstance(model_path, str):
return Path(model_path)
msg = "model_path must be a tuple (model_id, iteration) or a string path"
raise ValueError(msg)

def get_model_path_by_args(
self,
model_id: str | None = None,
iteration: int | str | None = None,
full_path: str | Path | None = None,
) -> Path:
"""
Get the path for loading a model checkpoint.

Two modes of operation:
1. Component mode: Provide model_id and iteration to construct path
2. Full path mode: Provide full_path directly

Args:
model_id: Model identifier (used in component mode)
iteration: Training iteration (used in component mode)
full_path: Full path to checkpoint (used in full path mode)

Returns:
Path to the checkpoint file

Raises:
ValueError: If neither component parameters nor full_path are provided,
or if both modes are attempted simultaneously
"""
# Check which mode we're operating in
component_mode = model_id is not None or iteration is not None
full_path_mode = full_path is not None

if component_mode and full_path_mode:
msg = (
"Cannot use both component mode (model_id/iteration) "
"and full path mode simultaneously"
)
raise ValueError(msg)

if not component_mode and not full_path_mode:
msg = "Must provide either (model_id and iteration) or full_path"
raise ValueError(msg)

if full_path_mode:
# Full path mode: return the path as-is
return Path(full_path)

# Component mode: construct path from checkpoint directory, model_id, and iteration
if model_id is None or iteration is None:
msg = "Both model_id and iteration must be provided in component mode"
raise ValueError(msg)

filename = f"{model_id}_iter_{iteration}.pt"
return self.checkpoint_directory / filename

def save_checkpoint_info(self, info: dict) -> None:
"""
Save checkpoint information to a JSON file in the checkpoint directory.

Args:
info: Dictionary containing checkpoint metadata
"""
import json

info_file = self.checkpoint_directory / "checkpoint_info.json"
with open(info_file, "w") as f:
json.dump(info, f, indent=2)

def load_checkpoint_info(self) -> dict:
"""
Load checkpoint information from the checkpoint directory.

Returns:
Dictionary containing checkpoint metadata
"""
import json

info_file = self.checkpoint_directory / "checkpoint_info.json"
if info_file.exists():
with open(info_file) as f:
return json.load(f)
return {}

def __str__(self) -> str:
return (
f"CheckpointManager(root='{self.root_directory}', "
f"run='{self.run_name}', ckpt_dir='{self.checkpoint_directory}')"
)

def __repr__(self) -> str:
return self.__str__()


# Example usage:
if __name__ == "__main__":
import tempfile

# Create checkpoint manager with proper temp directory
with tempfile.TemporaryDirectory() as temp_dir:
ckpt_manager = CheckpointManager(temp_dir, "my_experiment")
print(f"Checkpoint directory: {ckpt_manager.get_checkpoint_directory()}")

# Component mode - construct path from components
model_path = ckpt_manager.get_model_path_by_args(model_id="transe", iteration=1000)
print(f"Model path (component mode): {model_path}")

# Full path mode - use existing full path
full_path = "/some/other/path/model.pt"
model_path = ckpt_manager.get_model_path_by_args(full_path=full_path)
print(f"Model path (full path mode): {model_path}")

+ 45
- 0
tools/params.py View File

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class CommonParams:
"""
Common parameters for the application.
"""

model_name: str
project_name: str
run_name: str
save_dpath: str
save_every: int = 1000
load_path: str | None = ""
log_dir: str = "./runs"
log_every: int = 10
log_console_every: int = 100
evaluate_only: bool = False


@dataclass
class TrainingParams:
"""
Parameters for training.
"""

lr: float = 0.001
min_lr: float = 0.0001
weight_decay: float = 0.0
t0: int = 50
lr_step: int = 10
gamma: float = 1.0
batch_size: int = 128
num_workers: int = 0
seed: int = 42

num_train_steps: int = 100
num_epochs: int = 100
eval_every: int = 5

validation_sample_size: int = 1000
validation_batch_size: int = 128

+ 417
- 0
tools/pretty_logger.py View File

# pretty_logger.py

import json
import logging
import logging.handlers
import sys
import threading
import time
from contextlib import ContextDecorator
from functools import wraps

# ANSI escape codes for colors
RESET = "\x1b[0m"
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = (
"\x1b[30m",
"\x1b[31m",
"\x1b[32m",
"\x1b[33m",
"\x1b[34m",
"\x1b[35m",
"\x1b[36m",
"\x1b[37m",
)

_LEVEL_COLORS = {
logging.DEBUG: CYAN,
logging.INFO: GREEN,
logging.WARNING: YELLOW,
logging.ERROR: RED,
logging.CRITICAL: MAGENTA,
}


def _supports_color() -> bool:
    """
    Detect whether the running terminal supports ANSI colors.
    """
    # Any TTY is assumed to support ANSI; modern Windows 10+ terminals do too.
    return bool(hasattr(sys.stdout, "isatty") and sys.stdout.isatty())


class PrettyFormatter(logging.Formatter):
"""
A logging.Formatter that outputs colorized logs to the console
(if enabled) and supports JSON formatting (if requested).
"""

    def __init__(
        self,
        fmt: str | None = None,
        datefmt: str = "%Y-%m-%d %H:%M:%S",
        use_colors: bool = True,
        json_output: bool = False,
    ):
super().__init__(fmt=fmt, datefmt=datefmt)
self.use_colors = use_colors and _supports_color()
self.json_output = json_output

def format(self, record: logging.LogRecord) -> str:
# If JSON output is requested, dump a dict of relevant fields
if self.json_output:
log_record = {
"timestamp": self.formatTime(record, self.datefmt),
"name": record.name,
"level": record.levelname,
"message": record.getMessage(),
}
# Include extra/contextual fields
for k, v in record.__dict__.get("extra_fields", {}).items():
log_record[k] = v
if record.exc_info:
log_record["exception"] = self.formatException(record.exc_info)
return json.dumps(log_record, ensure_ascii=False)

# Otherwise, build a colorized/“pretty” line
level_color = _LEVEL_COLORS.get(record.levelno, WHITE) if self.use_colors else ""
reset = RESET if self.use_colors else ""
timestamp = self.formatTime(record, self.datefmt)
name = record.name
level = record.levelname

# Basic format: [timestamp] [LEVEL] [name]: message
base = f"[{timestamp}] [{level}] [{name}]: {record.getMessage()}"

# If there are any extra fields attached via LoggerAdapter, append them
extra_fields = record.__dict__.get("extra_fields", {})
if extra_fields:
# key1=val1 key2=val2 …
extras_str = " ".join(f"{k}={v}" for k, v in extra_fields.items())
base = f"{base} :: {extras_str}"

# If exception info is present, include traceback
if record.exc_info:
exc_text = self.formatException(record.exc_info)
base = f"{base}\n{exc_text}"

if self.use_colors:
return f"{level_color}{base}{reset}"
else:
return base


class JsonFormatter(PrettyFormatter):
"""
Shortcut to force JSON formatting (ignores color flags).
"""

def __init__(self, datefmt: str = "%Y-%m-%d %H:%M:%S"):
super().__init__(fmt=None, datefmt=datefmt, use_colors=False, json_output=True)


class ContextFilter(logging.Filter):
"""
Inject extra_fields from a LoggerAdapter into the LogRecord, so the Formatter can find them.
"""

def filter(self, record: logging.LogRecord) -> bool:
extra = getattr(record, "extra", None)
if isinstance(extra, dict):
record.extra_fields = extra
else:
record.extra_fields = {}
return True


class Timer(ContextDecorator):
"""
Context manager & decorator to measure execution time of a code block or function.
Usage as context manager:
with Timer(logger, "block name"):
# code …
Usage as decorator:
@Timer(logger, "function foo")
def foo(...):
...
"""

def __init__(self, logger: logging.Logger, name: str = None, level: int = logging.INFO):
self.logger = logger
self.name = name or ""
self.level = level

def __enter__(self):
self.start_time = time.time()
return self

def __exit__(self, exc_type, exc, exc_tb):
elapsed = time.time() - self.start_time
label = f"[Timer:{self.name}]" if self.name else "[Timer]"
self.logger.log(self.level, f"{label} elapsed {elapsed:.4f} sec")
return False # do not suppress exceptions


def log_function(logger: logging.Logger, level: int = logging.DEBUG):
"""
Decorator factory to log function entry/exit and execution time.
Usage:
@log_function(logger, level=logging.INFO)
def my_func(...):
...
"""

def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
logger.log(level, f"➤ Entering {func.__name__}()")
start = time.time()
try:
result = func(*args, **kwargs)
except Exception:
logger.exception(f"✖ Exception in {func.__name__}()")
raise
elapsed = time.time() - start
logger.log(level, f"✔ Exiting {func.__name__}(), elapsed {elapsed:.4f} sec")
return result

return wrapper

return decorator


class PrettyLogger:
"""
Encapsulates a logging.Logger with “pretty” formatting, colorized console output,
rotating file handler, JSON option, contextual fields, etc.
"""

_instances = {}
_lock = threading.Lock()

def __new__(cls, name: str = "root", **kwargs):
"""
Implements a simple singleton: multiple calls with the same name return the same instance.
"""
with cls._lock:
if name not in cls._instances:
instance = super().__new__(cls)
cls._instances[name] = instance
return cls._instances[name]

def __init__(
self,
name: str = "root",
level: int = logging.DEBUG,
log_to_console: bool = True,
console_level: int = logging.DEBUG,
log_to_file: bool = False,
log_file: str = "app.log",
file_level: int = logging.INFO,
max_bytes: int = 10 * 1024 * 1024, # 10 MB
backup_count: int = 5,
formatter: PrettyFormatter = None,
json_output: bool = False,
):
"""
Initialize (or re‐configure) a PrettyLogger.

Args:
name: logger name.
level: root level for the logger (lowest level that will be processed).
log_to_console: whether to attach a console (StreamHandler).
console_level: level for console output.
log_to_file: whether to attach a rotating file handler.
log_file: path to the log file.
file_level: level for file output.
max_bytes: rotation threshold in bytes.
backup_count: number of backup files to keep.
formatter: an instance of PrettyFormatter. If None, a default is created.
json_output: if True, forces JSON formatting on all handlers.
"""
# Avoid re‐initializing if already done
logger = logging.getLogger(name)
if getattr(logger, "_pretty_logger_inited", False):
return

self.logger = logger
self.logger.setLevel(level)
self.logger.propagate = False # avoid double‐logging if root logger is configured elsewhere

        # Default format string; the file handler below needs it even when a
        # custom console formatter is supplied (previously `fmt` was undefined
        # in that case, raising NameError).
        fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s"
        # Create a default formatter if not supplied
        if formatter is None:
            formatter = PrettyFormatter(
                fmt=fmt, use_colors=not json_output, json_output=json_output
            )

# Add ContextFilter so extra_fields is always present
context_filter = ContextFilter()
self.logger.addFilter(context_filter)

# Console handler
if log_to_console:
ch = logging.StreamHandler(sys.stdout)
ch.setLevel(console_level)
ch.setFormatter(formatter)
self.logger.addHandler(ch)

# Rotating file handler
if log_to_file:
fh = logging.handlers.RotatingFileHandler(
filename=log_file,
maxBytes=max_bytes,
backupCount=backup_count,
encoding="utf-8",
)
fh.setLevel(file_level)
# For file, typically we don’t colorize
file_formatter = (
JsonFormatter(datefmt="%Y-%m-%d %H:%M:%S")
if json_output
else PrettyFormatter(fmt=fmt, use_colors=False, json_output=False)
)
fh.setFormatter(file_formatter)
self.logger.addHandler(fh)

# Mark as initialized
setattr(self.logger, "_pretty_logger_inited", True)

def addHandler(self, handler: logging.Handler):
"""
Add a custom handler to the logger.
This can be used to add any logging.Handler instance.
"""
self.logger.addHandler(handler)

def bind(self, **kwargs):
"""
Return a LoggerAdapter that automatically injects extra fields into all log records.
Example:
ctx_logger = pretty_logger.bind(request_id=123, user="alice")
ctx_logger.info("Something happened") # will include request_id and user in the output
"""
return logging.LoggerAdapter(self.logger, {"extra": kwargs})

def add_stream_handler(
self, stream=None, level: int = logging.DEBUG, formatter: PrettyFormatter = None
):
"""
Add an additional stream handler (e.g. to stderr).
"""
handler = logging.StreamHandler(stream or sys.stdout)
handler.setLevel(level)
if formatter is None:
fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s"
formatter = PrettyFormatter(fmt=fmt)
handler.setFormatter(formatter)
self.logger.addHandler(handler)
return handler

def add_rotating_file_handler(
self,
log_file: str,
level: int = logging.INFO,
max_bytes: int = 10 * 1024 * 1024,
backup_count: int = 5,
json_output: bool = False,
):
"""
Add another rotating file handler at runtime.
"""
fh = logging.handlers.RotatingFileHandler(
filename=log_file,
maxBytes=max_bytes,
backupCount=backup_count,
encoding="utf-8",
)
fh.setLevel(level)
if json_output:
formatter = JsonFormatter()
else:
fmt = "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s"
formatter = PrettyFormatter(fmt=fmt, use_colors=False, json_output=False)
fh.setFormatter(formatter)
self.logger.addHandler(fh)
return fh

def remove_handler(self, handler: logging.Handler):
"""
Remove a handler from the logger.
"""
self.logger.removeHandler(handler)

def set_level(self, level: int):
"""
Change the logger’s root level at runtime.
"""
self.logger.setLevel(level)

def get_logger(self) -> logging.Logger:
return self.logger

# Convenience methods to expose common logging calls directly from this object:

def debug(self, msg, *args, **kwargs):
self.logger.debug(msg, *args, **kwargs)

def info(self, msg, *args, **kwargs):
self.logger.info(msg, *args, **kwargs)

def warning(self, msg, *args, **kwargs):
self.logger.warning(msg, *args, **kwargs)

def error(self, msg, *args, **kwargs):
self.logger.error(msg, *args, **kwargs)

def critical(self, msg, *args, **kwargs):
self.logger.critical(msg, *args, **kwargs)

def exception(self, msg, *args, exc_info=True, **kwargs):
"""
Log an error message with exception traceback. By default exc_info=True.
"""
self.logger.error(msg, *args, exc_info=exc_info, **kwargs)


# Convenience function for a module‐level “global” pretty logger:
_global_loggers = {}
_global_lock = threading.Lock()


def get_pretty_logger(
name: str = "root",
level: int = logging.DEBUG,
log_to_console: bool = True,
console_level: int = logging.DEBUG,
log_to_file: bool = False,
log_file: str = "app.log",
file_level: int = logging.INFO,
max_bytes: int = 10 * 1024 * 1024,
backup_count: int = 5,
json_output: bool = False,
) -> PrettyLogger:
"""
Return a PrettyLogger instance for `name`, creating/configuring it on first use.
Subsequent calls with the same `name` will return the same logger (singleton‐style).
"""
with _global_lock:
if name not in _global_loggers:
pl = PrettyLogger(
name=name,
level=level,
log_to_console=log_to_console,
console_level=console_level,
log_to_file=log_to_file,
log_file=log_file,
file_level=file_level,
max_bytes=max_bytes,
backup_count=backup_count,
json_output=json_output,
)
_global_loggers[name] = pl
return _global_loggers[name]
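
# Example usage (sketch): a console logger with contextual fields and the Timer
# helper. File/JSON output works the same way via the constructor flags above.
if __name__ == "__main__":
    log = get_pretty_logger("demo")
    log.info("plain message")
    log.bind(run_id=7, user="alice").info("with context")  # appends run_id=7 user=alice
    with Timer(log.get_logger(), "sleepy block"):
        time.sleep(0.1)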

+ 66
- 0
tools/sampling.py View File

from collections.abc import Iterator

import numpy as np
import torch
from torch.utils.data.sampler import Sampler


def negative_sampling(
    pos: torch.Tensor,
    num_entities: int,
    num_negatives: int = 1,
    mode: Sequence[str] = ("head", "tail"),
) -> torch.Tensor:
    """Corrupt each positive (head, relation, tail) triple ``num_negatives``
    times by replacing its head or tail with a uniformly sampled entity."""
    neg = pos.repeat_interleave(num_negatives, dim=0)
    random_entities = torch.randint(0, num_entities, (neg.shape[0],), device=neg.device)
    if "head" in mode and "tail" in mode:
        # Corrupt either the head or the tail of each triple, never both at once.
        corrupt_head = torch.rand(neg.shape[0], device=neg.device) < 0.5
        neg[corrupt_head, 0] = random_entities[corrupt_head]
        neg[~corrupt_head, 2] = random_entities[~corrupt_head]
    elif "head" in mode:
        neg[:, 0] = random_entities
    elif "tail" in mode:
        neg[:, 2] = random_entities
    return neg
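

# Example (illustrative): with columns (head, relation, tail) and two
# negatives per positive, exactly one of head/tail is replaced per row:
#
#     pos = torch.tensor([[0, 0, 1], [2, 1, 3]])
#     neg = negative_sampling(pos, num_entities=10, num_negatives=2)
#     neg.shape  # torch.Size([4, 3])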


class DeterministicRandomSampler(Sampler):
"""
A deterministic random sampler that selects a fixed number of samples
in a reproducible manner using a given seed.
"""

    def __init__(self, dataset_size: int, sample_size: int = 0, seed: int = 42) -> None:
        """
        Args:
            dataset_size (int): Total number of items in the dataset.
            sample_size (int, optional): Number of samples to draw. If 0, use the full dataset.
            seed (int): Seed for reproducibility.
        """
        self.dataset_size = dataset_size
        self.seed = seed
        self.sample_size = sample_size if sample_size != 0 else self.dataset_size

# Ensure sample size is within dataset size
if self.sample_size > self.dataset_size:
msg = f"Sample size {self.sample_size} exceeds dataset size {self.dataset_size}."
raise ValueError(
msg,
)

self.indices = self._generate_deterministic_indices()

def _generate_deterministic_indices(self) -> list[int]:
"""
Generates a fixed random subset of indices using the given seed.
"""
rng = np.random.default_rng(self.seed) # NumPy's Generator for better reproducibility
all_indices = rng.permutation(self.dataset_size) # Shuffle full dataset indices
return all_indices[: self.sample_size].tolist() # Select only the desired number of samples

    def __iter__(self) -> Iterator[int]:
        """
        Yields the shuffled dataset indices.
        """
        return iter(self.indices)

def __len__(self) -> int:
"""
Returns the total number of samples drawn.
"""
return self.sample_size
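

# Usage sketch (illustrative only): for a fixed seed the sampler yields the
# same indices on every run, keeping validation subsets comparable across runs.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.arange(10))
    sampler = DeterministicRandomSampler(len(ds), sample_size=4, seed=42)
    loader = DataLoader(ds, batch_size=2, sampler=sampler)
    for (batch,) in loader:
        print(batch)  # same two batches of two indices every run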

+ 21
- 0
tools/tb_handler.py View File

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from torch.utils.tensorboard import SummaryWriter


class TensorBoardHandler(logging.Handler):
"""Convert each log record into a TensorBoard text summary."""

def __init__(self, writer: SummaryWriter, tag: str = "logs"):
super().__init__(level=logging.INFO)
self.writer = writer
self.tag = tag
self._step = 0 # will be incremented every emit

def emit(self, record: logging.LogRecord) -> None:
self.writer.add_text(self.tag, self.format(record), global_step=self._step)
self._step += 1
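

# Usage sketch (illustrative only): route INFO-level records into
# TensorBoard's text panel, next to the run's scalars.
if __name__ == "__main__":
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir="runs/demo")  # hypothetical log directory
    log = logging.getLogger("demo")
    log.setLevel(logging.INFO)
    handler = TensorBoardHandler(writer, tag="logs")
    handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
    log.addHandler(handler)
    log.info("this message lands under the 'logs' text tag")
    writer.close()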

+ 12
- 0
tools/train.py View File

from torch.optim.lr_scheduler import StepLR


class StepMinLR(StepLR):
def __init__(self, optimizer, step_size, gamma=0.1, min_lr=0.0, last_epoch=-1):
self.min_lr = min_lr
super().__init__(optimizer, step_size, gamma, last_epoch)

def get_lr(self):
# use StepLR's formula, then clamp
raw_lrs = super().get_lr()
return [max(lr, self.min_lr) for lr in raw_lrs]
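

# Quick check (illustrative only): with step_size=1 and gamma=0.1 the raw
# schedule would decay 0.1 -> 0.01 -> 0.001 -> ..., but min_lr=0.005 clamps it.
if __name__ == "__main__":
    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=0.1)
    scheduler = StepMinLR(optimizer, step_size=1, gamma=0.1, min_lr=0.005)
    for _ in range(3):
        optimizer.step()
        scheduler.step()
        print(optimizer.param_groups[0]["lr"])  # 0.01, then 0.005 twice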

+ 25
- 0
tools/utils.py View File

import random
from enum import Enum

import numpy as np
import torch


def set_seed(seed: int = 42) -> None:
"""
Set the random seed for reproducibility.

Args:
seed (int): The seed value to set for random number generation.
"""
torch.manual_seed(seed)
random.seed(seed)
    np.random.seed(seed)  # default_rng(seed) would create (and discard) a local Generator
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)


class Target(str, Enum): # ↔ PyKEEN's LABEL_* constants
HEAD = "head"
TAIL = "tail"
RELATION = "relation"

+ 0
- 0
training/__init__.py View File


+ 0
- 0
training/model_trainers/__init__.py View File


BIN
training/model_trainers/__pycache__/__init__.cpython-310.pyc View File


BIN
training/model_trainers/__pycache__/model_trainer_base.cpython-310.pyc View File


+ 85
- 0
training/model_trainers/model_trainer_base.py View File

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch
from torch import nn

if TYPE_CHECKING:
from torch.optim.lr_scheduler import LRScheduler


class ModelTrainerBase(ABC):
def __init__(
self,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: LRScheduler | None = None,
device: torch.device = "cpu",
) -> None:
self.model = model.to(device)
self.optimizer = optimizer
self.scheduler = scheduler
self.device = device

def save(self, save_dpath: str, step: int) -> None:
"""
Save the model and optimizer state to the specified directory.

Args:
save_dpath (str): The directory path where the model and optimizer state will be saved.
"""

data = {
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
"step": step,
}

if self.scheduler is not None:
data["scheduler"] = self.scheduler.state_dict()

torch.save(
data,
f"{save_dpath}/model-{step}.pt",
)

    def load(self, load_fpath: str) -> int:
        """
        Load the model and optimizer state from the specified checkpoint file.

        Args:
            load_fpath (str): Path to the checkpoint file from which the model and
                optimizer state will be loaded.

        Returns:
            int: The training step stored in the checkpoint.
        """
data = torch.load(load_fpath, map_location="cpu")
model_state_dict = data["model"]
        self.model.load_state_dict(model_state_dict)
        self.model.to(self.device)  # move everything to the configured device
optimizer_state_dict = data["optimizer"]
self.optimizer.load_state_dict(optimizer_state_dict)
if self.scheduler is not None:
scheduler_data = data.get("scheduler", {})
self.scheduler.load_state_dict(scheduler_data)

return data["step"]

@abstractmethod
def _loss(self, batch: torch.Tensor) -> torch.Tensor:
"""
Compute the loss for a batch of data.

Args:
batch (torch.Tensor): A tensor containing a batch of data.

Returns:
torch.Tensor: The computed loss for the batch.
"""
...

@abstractmethod
def train_step(self, batch: torch.Tensor) -> float: ...

@abstractmethod
def eval_step(self, batch: torch.Tensor) -> float: ...

+ 0
- 0
training/model_trainers/translation/__init__.py View File


BIN
training/model_trainers/translation/__pycache__/__init__.cpython-310.pyc View File


BIN
training/model_trainers/translation/__pycache__/trans_e_trainer.cpython-310.pyc View File


+ 148
- 0
training/model_trainers/translation/trans_e_trainer.py View File

import torch
import torch.nn.functional as F
from einops import repeat
from pykeen.sampling import BernoulliNegativeSampler
from pykeen.training import SLCWATrainingLoop
from pykeen.triples.instances import SLCWABatch
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

from metrics.ranking import simple_ranking
from models.translation.trans_e import TransE
from tools.params import TrainingParams
from training.model_trainers.model_trainer_base import ModelTrainerBase


class TransETrainer(ModelTrainerBase):
def __init__(
self,
model: TransE,
training_params: TrainingParams,
        triples_factory,  # pykeen TriplesFactory
mapped_triples: torch.Tensor,
device: torch.device,
margin: float = 1.0,
n_negative: int = 1,
loss_fn: str = "margin",
**kwargs,
) -> None:
optimizer = Adam(
model.inner_model.parameters(),
lr=training_params.lr,
weight_decay=training_params.weight_decay,
)

# scheduler = None # Placeholder for scheduler, can be implemented later
scheduler = StepLR(
optimizer,
step_size=training_params.lr_step,
gamma=training_params.gamma,
)

super().__init__(model, optimizer, scheduler, device)

self.triples_factory = triples_factory

        self.margin = margin
        self.n_negative = n_negative
        self.loss_fn = loss_fn
        # Temperature for self-adversarial weighting (used only when
        # loss_fn == "adversarial"); 1.0 is an assumed default, since the
        # original code referenced self.adv_temp without ever setting it.
        self.adv_temp = kwargs.get("adv_temp", 1.0)

self.negative_sampler = BernoulliNegativeSampler(
mapped_triples=mapped_triples.to(device),
num_entities=model.num_entities,
num_relations=model.num_relations,
num_negs_per_pos=n_negative,
).to(device)

self.training_loop = SLCWATrainingLoop(
model=model.inner_model,
triples_factory=triples_factory,
optimizer=optimizer,
negative_sampler=self.negative_sampler,
lr_scheduler=scheduler,
)

    def create_data_loader(
        self,
        triples_factory,  # pykeen TriplesFactory, or None for the training factory
        batch_size: int,
        shuffle: bool = True,
    ) -> DataLoader:
        triples_factory = self.triples_factory if triples_factory is None else triples_factory
return self.training_loop._create_training_data_loader(
triples_factory=triples_factory,
sampler=None,
batch_size=batch_size,
shuffle=shuffle,
drop_last=False,
)

def _loss(self, pos_scores: torch.Tensor, neg_scores: torch.Tensor) -> torch.Tensor:
"""
Compute the loss based on the positive and negative scores.

Args:
pos_scores (torch.Tensor): Scores for positive triples.
neg_scores (torch.Tensor): Scores for negative triples.

Returns:
torch.Tensor: The computed loss.
"""

k = neg_scores.shape[1]

if self.loss_fn == "margin":
target = -torch.ones_like(neg_scores) # want s_pos < s_neg
loss = F.margin_ranking_loss(
repeat(pos_scores, "b -> b k", k=k),
neg_scores,
target,
margin=self.margin,
)
return loss
if self.loss_fn == "adversarial":
# b = pos_scores.size(0)
# neg_scores = rearrange(neg_scores, "(b k) -> b k", b=b, k=k)

# importance weights (detach so gradients flow only through log-sigmoid terms)
w = F.softmax(self.adv_temp * neg_scores, dim=1).detach()

pos_term = -F.logsigmoid(self.margin - pos_scores) # (B,)
neg_term = -(w * F.logsigmoid(neg_scores - self.margin)).sum(dim=1) # (B,)

loss = (pos_term + neg_term).mean()
return loss
return None
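
    # Note: the "adversarial" branch above matches, as far as we can tell, the
    # self-adversarial negative sampling loss of Sun et al. (RotatE, 2019):
    #
    #   L = -log σ(γ - d_pos) - Σ_i p_i · log σ(d_neg_i - γ),
    #   p_i = softmax(α · d_neg_i)  (detached),
    #
    # where γ is the margin, α the adversarial temperature, and d(.) are
    # distance-style scores (lower = more plausible).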

def train_step(self, batch: SLCWABatch, step: int) -> float:
"""
Perform a training step on the given batch of data.

Args:
batch (torch.Tensor): A tensor containing a batch of triples.

Returns:
float: The computed loss for the batch.
"""

continue_training = step > 0

loss = self.training_loop.train(
triples_factory=self.triples_factory,
num_epochs=step + 1,
batch_size=self.training_loop._get_batch_size(batch),
continue_training=continue_training,
use_tqdm=False,
use_tqdm_batch=False,
label_smoothing=0.0, # Assuming no label smoothing for simplicity
# num_workers=3,
)[-1]

return loss

def eval_step(self, batch: torch.Tensor) -> torch.Tensor:
"""Return rank tensor for this batch (needed for MRR / Hits@K)."""
self.model.eval()
with torch.no_grad():
return simple_ranking(self.model, batch.to(self.device))
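

# Construction sketch (illustrative only; argument values are hypothetical and
# a PyKEEN TriplesFactory `tf` plus a populated TrainingParams are assumed):
#
#     trainer = TransETrainer(
#         model=trans_e_model,
#         training_params=params,
#         triples_factory=tf,
#         mapped_triples=tf.mapped_triples,
#         device=torch.device("cuda"),
#         margin=1.0,
#         n_negative=4,
#         loss_fn="margin",
#     )
#     loader = trainer.create_data_loader(None, batch_size=params.batch_size)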

+ 121
- 0
training/model_trainers/translation/trans_h_trainer.py View File

import torch
from pykeen.losses import MarginRankingLoss
from pykeen.sampling import BernoulliNegativeSampler
from pykeen.training import SLCWATrainingLoop
from pykeen.triples.instances import SLCWABatch
from torch.optim import Adam

from metrics.ranking import simple_ranking
from models.translation.trans_h import TransH
from tools.params import TrainingParams
from tools.train import StepMinLR
from training.model_trainers.model_trainer_base import ModelTrainerBase


class TransHTrainer(ModelTrainerBase):
def __init__(
self,
model: TransH,
training_params: TrainingParams,
        triples_factory,  # pykeen TriplesFactory
mapped_triples: torch.Tensor,
device: torch.device,
margin: float = 1.0,
n_negative: int = 1,
regul_rate: float = 0.0,
loss_fn: str = "margin",
**kwargs,
) -> None:
optimizer = Adam(
model.inner_model.parameters(),
lr=training_params.lr,
weight_decay=training_params.weight_decay,
)

scheduler = StepMinLR(
optimizer,
step_size=training_params.lr_step,
gamma=training_params.gamma,
min_lr=training_params.min_lr,
)

super().__init__(model, optimizer, scheduler, device)

self.triples_factory = triples_factory

self.margin = margin
self.n_negative = n_negative
self.regul_rate = regul_rate
self.loss_fn = loss_fn

self.negative_sampler = BernoulliNegativeSampler(
mapped_triples=mapped_triples.to(device),
num_entities=model.num_entities,
num_relations=model.num_relations,
num_negs_per_pos=n_negative,
).to(device)

self.training_loop = SLCWATrainingLoop(
model=model.inner_model,
triples_factory=triples_factory,
optimizer=optimizer,
negative_sampler=self.negative_sampler,
lr_scheduler=scheduler,
)

def _loss(self) -> torch.Tensor:
return self.model.loss

    def create_data_loader(
        self,
        triples_factory,  # pykeen TriplesFactory, or None for the training factory
        batch_size: int,
        shuffle: bool = True,
    ) -> torch.utils.data.DataLoader:
        triples_factory = self.triples_factory if triples_factory is None else triples_factory
return self.training_loop._create_training_data_loader(
triples_factory=triples_factory,
sampler=None,
batch_size=batch_size,
shuffle=shuffle,
drop_last=False,
)

    @property
    def loss(self) -> MarginRankingLoss:
        return MarginRankingLoss(margin=self.margin, reduction="mean")

def train_step(
self,
batch: SLCWABatch,
step: int,
) -> float:
"""
Perform a training step on the given batch of data.

Args:
batch (torch.Tensor): A tensor containing a batch of triples.

Returns:
float: The computed loss for the batch.
"""

continue_training = step > 0

loss = self.training_loop.train(
triples_factory=self.triples_factory,
num_epochs=step + 1,
batch_size=self.training_loop._get_batch_size(batch),
continue_training=continue_training,
use_tqdm=False,
use_tqdm_batch=False,
label_smoothing=0.0, # Assuming no label smoothing for simplicity
)[-1]

return loss

def eval_step(self, batch: torch.Tensor) -> torch.Tensor:
"""Return rank tensor for this batch (needed for MRR / Hits@K)."""
self.model.eval()
with torch.no_grad():
return simple_ranking(self.model, batch.to(self.device))

+ 116
- 0
training/model_trainers/translation/trans_r_trainer.py View File

import torch
from pykeen.sampling import BernoulliNegativeSampler
from pykeen.training import SLCWATrainingLoop
from pykeen.triples.instances import SLCWABatch
from torch.optim import Adam

from metrics.ranking import simple_ranking
from models.translation.trans_r import TransR
from tools.params import TrainingParams
from tools.train import StepMinLR
from training.model_trainers.model_trainer_base import ModelTrainerBase


class TransRTrainer(ModelTrainerBase):
def __init__(
self,
model: TransR,
training_params: TrainingParams,
        triples_factory,  # pykeen TriplesFactory
mapped_triples: torch.Tensor,
device: torch.device,
margin: float = 1.0,
n_negative: int = 1,
regul_rate: float = 0.0,
loss_fn: str = "margin",
**kwargs,
) -> None:
optimizer = Adam(
model.inner_model.parameters(),
lr=training_params.lr,
weight_decay=training_params.weight_decay,
)

scheduler = StepMinLR(
optimizer,
step_size=training_params.lr_step,
gamma=training_params.gamma,
min_lr=training_params.min_lr,
)

super().__init__(model, optimizer, scheduler, device)

self.triples_factory = triples_factory

self.margin = margin
self.n_negative = n_negative
self.regul_rate = regul_rate
self.loss_fn = loss_fn

self.negative_sampler = BernoulliNegativeSampler(
mapped_triples=mapped_triples.to(device),
num_entities=model.num_entities,
num_relations=model.num_relations,
num_negs_per_pos=n_negative,
).to(device)

self.training_loop = SLCWATrainingLoop(
model=model.inner_model,
triples_factory=triples_factory,
optimizer=optimizer,
negative_sampler=self.negative_sampler,
lr_scheduler=scheduler,
)

    def create_data_loader(
        self,
        triples_factory,  # pykeen TriplesFactory, or None for the training factory
        batch_size: int,
        shuffle: bool = True,
    ) -> torch.utils.data.DataLoader:
        triples_factory = self.triples_factory if triples_factory is None else triples_factory
return self.training_loop._create_training_data_loader(
triples_factory=triples_factory,
sampler=None,
batch_size=batch_size,
shuffle=shuffle,
drop_last=False,
)

def _loss(self) -> torch.Tensor:
return self.model.loss

def train_step(
self,
batch: SLCWABatch,
step: int,
) -> float:
"""
Perform a training step on the given batch of data.

Args:
batch (torch.Tensor): A tensor containing a batch of triples.

Returns:
float: The computed loss for the batch.
"""

continue_training = step > 0

loss = self.training_loop.train(
triples_factory=self.triples_factory,
num_epochs=step + 1,
batch_size=self.training_loop._get_batch_size(batch),
continue_training=continue_training,
use_tqdm=False,
use_tqdm_batch=False,
label_smoothing=0.0, # Assuming no label smoothing for simplicity
)[-1]

return loss

def eval_step(self, batch: torch.Tensor) -> torch.Tensor:
"""Return rank tensor for this batch (needed for MRR / Hits@K)."""
self.model.eval()
with torch.no_grad():
return simple_ranking(self.model, batch.to(self.device))

+ 214
- 0
training/trainer.py View File

from __future__ import annotations

from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Any

import torch
from pykeen.evaluation import RankBasedEvaluator
from pykeen.metrics.ranking import HitsAtK, InverseHarmonicMeanRank
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from tools import (
CheckpointManager,
DeterministicRandomSampler,
TensorBoardHandler,
get_pretty_logger,
set_seed,
)

if TYPE_CHECKING:
from data.kg_dataset import KGDataset
from tools import CommonParams, TrainingParams

from .model_trainers.model_trainer_base import ModelTrainerBase


logger = get_pretty_logger(__name__)


cpu_device = torch.device("cpu")


class Trainer:
def __init__(
self,
train_dataset: KGDataset,
val_dataset: KGDataset | None,
model_trainer: ModelTrainerBase,
common_params: CommonParams,
training_params: TrainingParams,
device: torch.device = cpu_device,
) -> None:
set_seed(training_params.seed)

self.train_dataset = train_dataset
self.val_dataset = val_dataset if val_dataset is not None else train_dataset
self.model_trainer = model_trainer

self.common_params = common_params
self.training_params = training_params

self.device = device

self.train_loader = self.model_trainer.create_data_loader(
triples_factory=self.train_dataset.triples_factory,
batch_size=training_params.batch_size,
shuffle=True,
)

self.train_iterator = self._data_iterator(self.train_loader)

seed = 1234
val_sampler = DeterministicRandomSampler(
len(self.val_dataset),
sample_size=self.training_params.validation_sample_size,
seed=seed,
)
self.valid_loader = DataLoader(
self.val_dataset,
batch_size=training_params.validation_batch_size,
shuffle=False,
num_workers=training_params.num_workers,
sampler=val_sampler,
)
self.valid_iterator = self._data_iterator(self.valid_loader)

self.step = 0
self.log_every = common_params.log_every
self.log_console_every = common_params.log_console_every

self.save_dpath = common_params.save_dpath
self.save_every = common_params.save_every

self.checkpoint_manager = CheckpointManager(
root_directory=self.save_dpath,
run_name=common_params.run_name,
load_only=common_params.evaluate_only,
)
self.ckpt_dpath = self.checkpoint_manager.checkpoint_directory
if not common_params.evaluate_only:
logger.info(f"Saving checkpoints to {self.ckpt_dpath}")

if self.common_params.load_path:
self.load()

self.writer = None
if not common_params.evaluate_only:
run_name = self.checkpoint_manager.run_name
self.writer = SummaryWriter(log_dir=f"{common_params.log_dir}/{run_name}")
            logger.get_logger().addHandler(TensorBoardHandler(self.writer))

    def log(self, tag: str, log_value: Any, console_msg: str | None = None) -> None:
        """
        Log a scalar to TensorBoard and optionally echo a message to the console.

        Args:
            tag (str): TensorBoard tag under which the value is recorded.
            log_value (Any): Scalar value recorded at the current step.
            console_msg (str | None): Optional message to print to the console log.
        """
        if console_msg is not None:
            logger.info(console_msg)
        if self.writer is not None:
            self.writer.add_scalar(tag, log_value, self.step)

def save(self) -> None:
save_dpath = self.ckpt_dpath

save_dpath = Path(save_dpath)
save_dpath.mkdir(parents=True, exist_ok=True)
self.model_trainer.save(save_dpath, self.step)

def load(self) -> None:
load_path = self.checkpoint_manager.get_model_fpath(self.common_params.load_path)
load_fpath = Path(load_path)

if load_fpath.exists():
self.step = self.model_trainer.load(load_fpath)
logger.info(f"Model loaded from {load_fpath}.")
else:
msg = f"Load path {load_fpath} does not exist."
raise FileNotFoundError(msg)

    def _data_iterator(self, loader: DataLoader) -> Iterator:
        """Cycle through `loader` indefinitely, yielding batches of triples."""
        while True:
            yield from loader

    def train(self) -> None:
        """Run a single optimisation step and log the loss periodically."""
data_batch = next(self.train_iterator)

loss = self.model_trainer.train_step(data_batch, self.step)

if self.step % self.log_console_every == 0:
logger.info(f"[train] step {self.step:6d} | loss = {loss:.4f}")
if self.step % self.log_every == 0:
self.log("train/loss", loss)

self.step += 1

    def evaluate(self) -> dict[str, float]:
        ks = [1, 3, 10]

        rank_metrics = [InverseHarmonicMeanRank()] + [HitsAtK(k) for k in ks]

        evaluator = RankBasedEvaluator(
            metrics=rank_metrics,
            filtered=True,
            batch_size=self.training_params.validation_batch_size,
        )

result = evaluator.evaluate(
model=self.model_trainer.model,
mapped_triples=self.val_dataset.triples,
additional_filter_triples=[
self.train_dataset.triples,
self.val_dataset.triples,
],
batch_size=self.training_params.validation_batch_size,
)

mrr = result.get_metric("both.realistic.inverse_harmonic_mean_rank")
hits_at = {f"hits_at_{k}": result.get_metric(f"both.realistic.hits_at_{k}") for k in ks}

metrics = {
"mrr": mrr,
**hits_at,
}

        for k, v in metrics.items():
            self.log(f"valid/{k}", v, console_msg=f"[valid] {k:6s} = {v:.4f}")

        return metrics

# ───────────────────────── #
# master schedule
# ───────────────────────── #
def run(self) -> None:
"""
Alternate training and evaluation until `total_steps`
optimisation steps have been executed.
"""

total_steps = self.training_params.num_train_steps
eval_every = self.training_params.eval_every

while self.step < total_steps:
self.train()
if self.step % eval_every == 0:
logger.info(f"Evaluating at step {self.step}...")
self.evaluate()

if self.step % self.save_every == 0:
logger.info(f"Saving model at step {self.step}...")
self.save()

logger.info("Training complete.")

    def reset(self) -> None:
        """
        Reset the trainer state, including the model trainer and data iterators.

        Assumes the concrete model trainer implements a ``reset()`` method
        (not part of ``ModelTrainerBase``).
        """
        self.model_trainer.reset()
self.train_iterator = self._data_iterator(self.train_loader)
self.valid_iterator = self._data_iterator(self.valid_loader)
self.step = 0
logger.info("Trainer state has been reset.")
