from __future__ import annotations import math from collections import defaultdict from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, Iterable import torch from torch import Tensor from .base_metric import BaseMetric if TYPE_CHECKING: from data.kg_dataset import KGDataset from models.base_model import ERModel @torch.no_grad() def _log_prob_distribution( fn: Callable[[Tensor], Tensor], query: Tensor, batch_size: int, ) -> Iterable[Tensor]: """ Yield *log*-probability batches for an arbitrary distribution helper. Parameters ---------- fn : A bound method of *model* returning **log**-probabilities, e.g. ``model.tail_distribution(..., log=True)``. query : LongTensor of shape (N, 2) containing the query IDs that `fn` expects. batch_size : Mini-batch size. Choose according to available GPU/CPU memory. """ for i in range(0, len(query), batch_size): yield fn(query[i : i + batch_size]) # (B, |Candidates|) def _kl_uniform(log_p: Tensor, true_index_sets: list[list[int]]) -> Tensor: """ KL(q || p) where q is uniform over *true_index_sets*. Parameters ---------- log_p : Tensor of shape (B, N) — log-probabilities from the model. true_index_sets : List (length B). Element *b* contains the list of **indices** that are correct for row *b*. Returns ------- kl : Tensor, shape (B,) — KL divergence per row (natural log). """ rows: list[Tensor] = [] for lp_row, idx in zip(log_p, true_index_sets): k = len(idx) rows.append(math.log(k) - lp_row[idx].mean()) # log k - E_q[log p] return torch.tensor(rows, device=log_p.device) @dataclass(slots=True) class _Accum: """Simple accumulator for ∑ value and ∑ weight.""" tot_v: float = 0.0 tot_w: float = 0.0 def update(self, value: float, weight: float) -> None: self.tot_v += value * weight self.tot_w += weight @property def mean(self) -> float: return self.tot_v / self.tot_w if self.tot_w else float("nan") class CSWKLF(BaseMetric): """ C-SWKLF metric for evaluating knowledge graph embeddings. This metric is used to evaluate the quality of knowledge graph embeddings by computing the C-SWKLF score. """ def __init__(self, dataset: KGDataset, model: ERModel) -> None: super().__init__(dataset) self.model = model @torch.no_grad() def compute( self, alpha: float = 1 / 3, beta: float = 1 / 3, gamma: float = 1 / 3, batch_size: int = 1024, slice_size: int | None = 2048, device: torch.device | str | None = None, ) -> float: """ Compute *Comprehensive Structural-Weighted KL-Fitness*. Parameters ---------- model : A trained PyKEEN ERModel **with** the three distribution helpers. kg : `pykeen.triples.KGInfo` instance providing `mapped_triples`. alpha, beta, gamma : Non-negative weights for the three query types, default 1/3 each. Must sum to *1*. batch_size : Number of queries scored per forward pass. Tune wrt. GPU/CPU RAM. slice_size : Forwarded to `tail_distribution` / `head_distribution` / `relation_distribution`. `None` disables slicing. device : Where to do the computation. `None` ⇒ first param's device. Returns ------- score : float in (0, 1] Higher = model closer to empirical distributions across *all* directions. """ # --------------------------------------------------------------------- # # Preparation # # --------------------------------------------------------------------- # assert abs(alpha + beta + gamma - 1.0) < 1e-9, "α+β+γ must be 1." model_device = next(self.model.parameters()).device if device is None: device = model_device triples = self.dataset.triples.to(device) # (|T|, 3) heads, rels, tails = triples.t() # (|T|,) # Build index structures -------------------------------------------------- tails_by_hr: dict[tuple[int, int], list[int]] = defaultdict(list) heads_by_rt: dict[tuple[int, int], list[int]] = defaultdict(list) rels_by_ht: dict[tuple[int, int], list[int]] = defaultdict(list) for h, r, t in zip(heads.tolist(), rels.tolist(), tails.tolist()): tails_by_hr[(h, r)].append(t) heads_by_rt[(r, t)].append(h) rels_by_ht[(h, t)].append(r) # Structural entropies & query lists --------------------------------------- H_tail: dict[int, _Accum] = defaultdict(_Accum) # per relation r H_head: dict[int, _Accum] = defaultdict(_Accum) tail_queries: list[tuple[int, int]] = [] # (h,r) head_queries: list[tuple[int, int]] = [] # (r,t) rel_queries: list[tuple[int, int]] = [] # (h,t) # --- tails -------------------------------------------------------------- for (h, r), ts in tails_by_hr.items(): k = len(ts) H_tail[r].update(math.log(k), 1.0) tail_queries.append((h, r)) # --- heads -------------------------------------------------------------- for (r, t), hs in heads_by_rt.items(): k = len(hs) H_head[r].update(math.log(k), 1.0) head_queries.append((r, t)) # --- relations ---------------------------------------------------------- H_rel_accum = _Accum() for (h, t), rs in rels_by_ht.items(): k = len(rs) H_rel_accum.update(math.log(k), 1.0) rel_queries.append((h, t)) H_struct_tail = {r: acc.mean for r, acc in H_tail.items()} H_struct_head = {r: acc.mean for r, acc in H_head.items()} # H_struct_rel = H_rel_accum.mean # scalar # --------------------------------------------------------------------- # # KL divergences # # --------------------------------------------------------------------- # def _avg_kl_per_relation( queries: list[tuple[int, int]], true_idx_map: dict[tuple[int, int], list[int]], distribution_fn: Callable[[Tensor], Tensor], num_relations: bool = False, # switch to know output size ) -> dict[int, float]: """ Compute *average* KL(q || p) **per relation**. Works for the tail & head conditionals. """ kl_sum: dict[int, float] = defaultdict(float) count: dict[int, int] = defaultdict(int) query_tensor = torch.tensor(queries, dtype=torch.long, device=device) for log_p in _log_prob_distribution(distribution_fn, query_tensor, batch_size): # slice_size handled inside distribution_fn # log_p : (B, |Candidates|) b = log_p.shape[0] # Map back to current slice of queries slice_queries = queries[:b] queries[:] = queries[b:] # consume true_sets = [true_idx_map[q] for q in slice_queries] kl_row = _kl_uniform(log_p, true_sets) # (B,) for (x, r), kl_val in zip(slice_queries, kl_row.tolist()): rel_id = r if not num_relations else x kl_sum[rel_id] += kl_val count[rel_id] += 1 return {r: kl_sum[r] / count[r] for r in kl_sum} # --- tails -------------------------------------------------------------- tail_queries_copy = tail_queries.copy() D_tail = _avg_kl_per_relation( tail_queries_copy, tails_by_hr, lambda q, s=slice_size: self.model.tail_distribution(q, log=True, slice_size=s), ) S_tail = {r: math.exp(-d) for r, d in D_tail.items()} # --- heads -------------------------------------------------------------- head_queries_copy = head_queries.copy() D_head = _avg_kl_per_relation( head_queries_copy, heads_by_rt, lambda q, s=slice_size: self.model.head_distribution(q, log=True, slice_size=s), num_relations=True, ) S_head = {r: math.exp(-d) for r, d in D_head.items()} # --- relations (single global number) ---------------------------------- D_rel_sum = 0.0 for log_p in _log_prob_distribution( lambda q, s=slice_size: self.model.relation_distribution(q, log=True, slice_size=s), torch.tensor(rel_queries, dtype=torch.long, device=device), batch_size, ): slice_size_batch = log_p.shape[0] slice_queries = rel_queries[:slice_size_batch] rel_queries[:] = rel_queries[slice_size_batch:] true_sets = [rels_by_ht[q] for q in slice_queries] D_rel_sum += _kl_uniform(log_p, true_sets).sum().item() D_rel = D_rel_sum / H_rel_accum.tot_w S_rel = math.exp(-D_rel) # --------------------------------------------------------------------- # # Weighted aggregation → C-SWKLF # # --------------------------------------------------------------------- # def _weighted_mean(score: dict[int, float], weight: dict[int, float]) -> float: num = sum(weight[r] * score[r] for r in score) den = sum(weight[r] for r in score) return num / den if den else float("nan") sw_tail = _weighted_mean(S_tail, H_struct_tail) sw_head = _weighted_mean(S_head, H_struct_head) # Relations: numerator & denominator cancel sw_rel = S_rel return alpha * sw_tail + beta * sw_head + gamma * sw_rel