from __future__ import annotations

import math
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Iterable

import torch
from torch import Tensor

from .base_metric import BaseMetric

if TYPE_CHECKING:
    from data.kg_dataset import KGDataset
    from models.base_model import ERModel


@torch.no_grad()
def _log_prob_distribution(
    fn: Callable[[Tensor], Tensor],
    query: Tensor,
    batch_size: int,
) -> Iterable[Tensor]:
    """
    Yield *log*-probability batches for an arbitrary distribution helper.

    Parameters
    ----------
    fn :
        A bound method of *model* returning **log**-probabilities, e.g.
        ``model.tail_distribution(..., log=True)``.
    query :
        LongTensor of shape (N, 2) containing the query IDs that `fn` expects.
    batch_size :
        Mini-batch size. Choose according to available GPU/CPU memory.
    """
    for i in range(0, len(query), batch_size):
        yield fn(query[i : i + batch_size])  # (B, |Candidates|)

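# Illustrative usage sketch (hypothetical names; assumes a model exposing
# ``tail_distribution(query, log=True)`` as described above):
#
#     hr_queries = torch.tensor([[0, 1], [2, 1], [3, 0]])  # (h, r) pairs
#     fn = lambda q: model.tail_distribution(q, log=True)
#     for log_p in _log_prob_distribution(fn, hr_queries, batch_size=2):
#         print(log_p.shape)  # (2, num_entities), then (1, num_entities)

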
def _kl_uniform(log_p: Tensor, true_index_sets: list[list[int]]) -> Tensor:
    """
    KL(q || p) where q is uniform over *true_index_sets*.

    Parameters
    ----------
    log_p :
        Tensor of shape (B, N) — log-probabilities from the model.
    true_index_sets :
        List (length B). Element *b* contains the list of **indices**
        that are correct for row *b*.

    Returns
    -------
    kl : Tensor, shape (B,) — KL divergence per row (natural log).
    """
    rows: list[Tensor] = []
    for lp_row, idx in zip(log_p, true_index_sets):
        k = len(idx)
        # KL(q || p) = -E_q[log p] - H(q) = -mean_{i in idx}(log p_i) - log k
        rows.append(-lp_row[idx].mean() - math.log(k))
    return torch.stack(rows)

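# Sanity check (illustrative): when p matches the uniform target q on its
# support, the divergence vanishes. With k = 2 true indices out of N = 4:
#
#     log_p = torch.log(torch.tensor([[0.5, 0.5, 0.0, 0.0]]) + 1e-12)
#     _kl_uniform(log_p, [[0, 1]])  # ≈ 0.0: -mean(log 0.5) - log 2 = 0

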
@dataclass(slots=True)
class _Accum:
    """Simple accumulator for ∑ value and ∑ weight."""

    tot_v: float = 0.0
    tot_w: float = 0.0

    def update(self, value: float, weight: float) -> None:
        self.tot_v += value * weight
        self.tot_w += weight

    @property
    def mean(self) -> float:
        return self.tot_v / self.tot_w if self.tot_w else float("nan")

class CSWKLF(BaseMetric):
    """
    Comprehensive Structural-Weighted KL-Fitness (C-SWKLF).

    Scores a knowledge graph embedding model by comparing its tail, head,
    and relation conditionals against the empirical distributions induced
    by the dataset's triples, weighting each relation by its structural
    entropy. Higher scores indicate a closer fit.
    """

    def __init__(self, dataset: KGDataset, model: ERModel) -> None:
        super().__init__(dataset)
        self.model = model

    @torch.no_grad()
    def compute(
        self,
        alpha: float = 1 / 3,
        beta: float = 1 / 3,
        gamma: float = 1 / 3,
        batch_size: int = 1024,
        slice_size: int | None = 2048,
        device: torch.device | str | None = None,
    ) -> float:
        """
        Compute *Comprehensive Structural-Weighted KL-Fitness*.

        Requires ``self.model`` (an ``ERModel``) to provide the three
        distribution helpers ``tail_distribution``, ``head_distribution``,
        and ``relation_distribution``, and ``self.dataset`` to provide the
        mapped triples.

        Parameters
        ----------
        alpha, beta, gamma :
            Non-negative weights for the tail, head, and relation query
            types; default 1/3 each. Must sum to *1*.
        batch_size :
            Number of queries scored per forward pass. Tune w.r.t. available
            GPU/CPU RAM.
        slice_size :
            Forwarded to `tail_distribution` / `head_distribution` /
            `relation_distribution`. `None` disables slicing.
        device :
            Where to do the computation. `None` ⇒ the model's parameter
            device.

        Returns
        -------
        score : float in (0, 1]
            Higher = model closer to the empirical distributions across
            *all* query directions.
        """
        # ------------------------------------------------------------------ #
        # Preparation                                                        #
        # ------------------------------------------------------------------ #
        assert abs(alpha + beta + gamma - 1.0) < 1e-9, "α+β+γ must be 1."

        model_device = next(self.model.parameters()).device
        if device is None:
            device = model_device
        triples = self.dataset.triples.to(device)  # (|T|, 3)

        heads, rels, tails = triples.t()  # (|T|,)

        # Build index structures --------------------------------------------
        tails_by_hr: dict[tuple[int, int], list[int]] = defaultdict(list)
        heads_by_rt: dict[tuple[int, int], list[int]] = defaultdict(list)
        rels_by_ht: dict[tuple[int, int], list[int]] = defaultdict(list)

        for h, r, t in zip(heads.tolist(), rels.tolist(), tails.tolist()):
            tails_by_hr[(h, r)].append(t)
            heads_by_rt[(r, t)].append(h)
            rels_by_ht[(h, t)].append(r)

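        # Example (illustrative): triples [(0, 1, 2), (0, 1, 3), (4, 1, 2)]
        # give tails_by_hr[(0, 1)] == [2, 3], heads_by_rt[(1, 2)] == [0, 4],
        # and rels_by_ht[(0, 2)] == [1].
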
        # Structural entropies & query lists --------------------------------
        H_tail: dict[int, _Accum] = defaultdict(_Accum)  # per relation r
        H_head: dict[int, _Accum] = defaultdict(_Accum)
        tail_queries: list[tuple[int, int]] = []  # (h, r)
        head_queries: list[tuple[int, int]] = []  # (r, t)
        rel_queries: list[tuple[int, int]] = []  # (h, t)

        # --- tails ----------------------------------------------------------
        for (h, r), ts in tails_by_hr.items():
            k = len(ts)
            H_tail[r].update(math.log(k), 1.0)
            tail_queries.append((h, r))

        # --- heads ----------------------------------------------------------
        for (r, t), hs in heads_by_rt.items():
            k = len(hs)
            H_head[r].update(math.log(k), 1.0)
            head_queries.append((r, t))

        # --- relations --------------------------------------------------------
        H_rel_accum = _Accum()
        for (h, t), rs in rels_by_ht.items():
            k = len(rs)
            H_rel_accum.update(math.log(k), 1.0)
            rel_queries.append((h, t))

        H_struct_tail = {r: acc.mean for r, acc in H_tail.items()}
        H_struct_head = {r: acc.mean for r, acc in H_head.items()}
        # H_struct_rel = H_rel_accum.mean  # scalar; cancels in the aggregation

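        # H_struct_tail[r] is the mean over r's (h, r) queries of log k, i.e.
        # the entropy of the uniform empirical tail distribution. Illustrative:
        # two queries with 2 and 4 true tails give (log 2 + log 4) / 2 ≈ 1.04.
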
        # ------------------------------------------------------------------ #
        # KL divergences                                                     #
        # ------------------------------------------------------------------ #
        def _avg_kl_per_relation(
            queries: list[tuple[int, int]],
            true_idx_map: dict[tuple[int, int], list[int]],
            distribution_fn: Callable[[Tensor], Tensor],
            relation_first: bool = False,  # True ⇒ relation ID is the first element of each query tuple
        ) -> dict[int, float]:
            """
            Compute *average* KL(q || p) **per relation**.

            Works for both the tail and head conditionals.
            """
            kl_sum: dict[int, float] = defaultdict(float)
            count: dict[int, int] = defaultdict(int)

            query_tensor = torch.tensor(queries, dtype=torch.long, device=device)

            offset = 0  # position of the current batch within `queries`
            for log_p in _log_prob_distribution(distribution_fn, query_tensor, batch_size):
                # slice_size is handled inside distribution_fn.
                # log_p : (B, |Candidates|)
                b = log_p.shape[0]
                slice_queries = queries[offset : offset + b]
                offset += b

                true_sets = [true_idx_map[q] for q in slice_queries]
                kl_row = _kl_uniform(log_p, true_sets)  # (B,)

                for (first, second), kl_val in zip(slice_queries, kl_row.tolist()):
                    rel_id = first if relation_first else second
                    kl_sum[rel_id] += kl_val
                    count[rel_id] += 1

            return {r: kl_sum[r] / count[r] for r in kl_sum}

        # --- tails ----------------------------------------------------------
        D_tail = _avg_kl_per_relation(
            tail_queries,
            tails_by_hr,
            lambda q, s=slice_size: self.model.tail_distribution(q, log=True, slice_size=s),
        )
        S_tail = {r: math.exp(-d) for r, d in D_tail.items()}

        # --- heads ----------------------------------------------------------
        D_head = _avg_kl_per_relation(
            head_queries,
            heads_by_rt,
            lambda q, s=slice_size: self.model.head_distribution(q, log=True, slice_size=s),
            relation_first=True,
        )
        S_head = {r: math.exp(-d) for r, d in D_head.items()}

        # --- relations (single global number) --------------------------------
        D_rel_sum = 0.0
        offset = 0
        for log_p in _log_prob_distribution(
            lambda q, s=slice_size: self.model.relation_distribution(q, log=True, slice_size=s),
            torch.tensor(rel_queries, dtype=torch.long, device=device),
            batch_size,
        ):
            b = log_p.shape[0]
            slice_queries = rel_queries[offset : offset + b]
            offset += b

            true_sets = [rels_by_ht[q] for q in slice_queries]
            D_rel_sum += _kl_uniform(log_p, true_sets).sum().item()

        D_rel = D_rel_sum / len(rel_queries)  # == H_rel_accum.tot_w
        S_rel = math.exp(-D_rel)

        # ------------------------------------------------------------------ #
        # Weighted aggregation → C-SWKLF                                     #
        # ------------------------------------------------------------------ #
        def _weighted_mean(score: dict[int, float], weight: dict[int, float]) -> float:
            num = sum(weight[r] * score[r] for r in score)
            den = sum(weight[r] for r in score)
            return num / den if den else float("nan")

        sw_tail = _weighted_mean(S_tail, H_struct_tail)
        sw_head = _weighted_mean(S_head, H_struct_head)
        # Relations carry a single global weight, so it cancels in the mean.
        sw_rel = S_rel

        return alpha * sw_tail + beta * sw_head + gamma * sw_rel
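

# Illustrative usage (``my_dataset`` and ``my_model`` are hypothetical; the
# model must expose the tail/head/relation distribution helpers used above):
#
#     metric = CSWKLF(dataset=my_dataset, model=my_model)
#     score = metric.compute(batch_size=512)
#     print(f"C-SWKLF = {score:.4f}")  # in (0, 1]; higher is better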