123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237 |
- from __future__ import annotations
-
- import math
- from collections import Counter, defaultdict
- from typing import TYPE_CHECKING, Dict, List, Tuple
-
- import torch
-
- from .base_metric import BaseMetric
-
- if TYPE_CHECKING:
- from data.kg_dataset import KGDataset
-
-
def _entropy_from_counter(counter: Counter[int], n: int) -> float:
    """Compute the Shannon entropy (in nats) of a colour histogram.

    ``counter`` maps each colour to its number of occurrences and ``n`` is
    the total used to turn those counts into probabilities.  Colours whose
    count is zero contribute nothing to the result.
    """
    entropy = 0.0
    for occurrences in counter.values():
        if not occurrences:
            # Skip empty colours: log(0) is undefined and they add no entropy.
            continue
        prob = occurrences / n
        entropy -= prob * math.log(prob)
    return entropy
-
-
def _normalise_scores(entropies: List[float], n: int) -> Tuple[float, float]:
    """Return ``(C_ratio, C_NWLEC)`` averaged over the per-layer entropies.

    For every layer entropy ``H_l`` (in nats) two normalised values are
    formed and then averaged over all layers:

    * ``C_ratio``  - mean of ``H_l / ln(n)``
    * ``C_NWLEC``  - mean of ``(exp(H_l) - 1) / (n - 1)``

    Both lie in ``[0, 1]`` whenever each entropy lies in ``[0, ln(n)]``.

    Parameters
    ----------
    entropies
        One entropy value per recorded layer.
    n
        Number of items the entropies were computed over.

    Returns
    -------
    (c_ratio, c_nwlec)
        ``(0.0, 0.0)`` when ``n <= 1`` (the normalisers would be zero) or
        when ``entropies`` is empty (there is nothing to average).
    """
    # Guard both degenerate inputs: ln(n) and (n - 1) vanish for n <= 1, and
    # an empty list would otherwise divide by zero when averaging.
    if n <= 1 or not entropies:
        return 0.0, 0.0

    num_layers = len(entropies)
    log_n = math.log(n)
    c_ratio = sum(ent / log_n for ent in entropies) / num_layers
    c_nwlec = sum((math.exp(ent) - 1) / (n - 1) for ent in entropies) / num_layers
    return c_ratio, c_nwlec
-
-
def _relation_families(
    triples: torch.Tensor, num_relations: int, *, thresh: float = 0.9
) -> torch.Tensor:
    """Return a LongTensor mapping each relation id → family id.

    A relation *r_inv* is proposed as the inverse of *r* when **at least
    `thresh` fraction of all edges labelled *r*** have a reverse edge
    ``(t, r_inv, h)``.  If several candidates pass the threshold, the one
    reversing the most edges is chosen (ties go to the smallest id).
    Relations linked by a proposal, directly or through a chain of
    proposals, share one family whose id is the smallest relation id in it.

    Parameters
    ----------
    triples
        Tensor whose rows are ``(head, relation, tail)`` integer ids.
    num_relations
        Number of relation ids; every relation id in ``triples`` must be
        smaller than this.
    thresh
        Required fraction of reversed edges, in ``(0, 1]``.

    Returns
    -------
    torch.Tensor
        1-D ``torch.long`` tensor of length ``num_relations``.

    Raises
    ------
    ValueError
        If ``thresh`` is outside ``(0, 1]``.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not 0.0 < thresh <= 1.0:
        raise ValueError(f"thresh must lie in (0, 1], got {thresh!r}")

    # Map each directed entity pair (h, t) to the relations labelling it.
    edge2rels: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for h, r, t in triples.tolist():
        edge2rels[(h, t)].append(int(r))

    # pair_counts[(r, r_rev)]: number of r-edges reversed by an r_rev-edge.
    # rel_totals[r]: number of *all* r-edges, reversed or not, so that the
    # ratio below really is "fraction of the edges labelled r".
    pair_counts: Dict[Tuple[int, int], int] = defaultdict(int)
    rel_totals: List[int] = [0] * num_relations
    for (h, t), rels in edge2rels.items():
        for r in rels:
            rel_totals[r] += 1
        rev_rels = edge2rels.get((t, h))
        if not rev_rels:
            continue
        # De-duplicate so a repeated reverse triple cannot count an edge twice.
        distinct_rev = set(rev_rels)
        for r in rels:
            for r_rev in distinct_rev:
                pair_counts[(r, r_rev)] += 1

    # Best inverse candidate per relation: r -> (count, r_inv).
    best: Dict[int, Tuple[int, int]] = {}
    for (r, r_rev), cnt in pair_counts.items():
        if cnt / rel_totals[r] < thresh:
            continue
        current = best.get(r)
        # Prefer the higher count; break ties by the smaller relation id so
        # the outcome does not depend on dictionary iteration order.
        if current is None or (cnt, -r_rev) > (current[0], -current[1]):
            best[r] = (cnt, r_rev)

    # Union-find with "smallest id is the root" gives a transitive grouping
    # while keeping family id == min relation id for plain inverse pairs.
    parent = list(range(num_relations))

    def _find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for r, (_, r_inv) in best.items():
        root_a, root_b = _find(r), _find(r_inv)
        if root_a == root_b:
            continue
        if root_a < root_b:
            parent[root_b] = root_a
        else:
            parent[root_a] = root_b

    family = [_find(r) for r in range(num_relations)]
    return torch.tensor(family, dtype=torch.long)
-
-
def _conditional_relation_entropy(
    colours: List[int],
    triples: torch.Tensor,
    family_map: torch.Tensor,
) -> float:
    """Return the entropy (nats) of relation families given endpoint colours.

    Every triple ``(h, r, t)`` is bucketed by the *unordered* pair of its
    endpoint colours, so the measure ignores edge direction.  Within each
    bucket the empirical distribution over relation families is formed, and
    the bucket entropies are averaged with weights proportional to bucket
    size.

    Parameters
    ----------
    colours
        Colour of each entity, indexed by entity id.
    triples
        Tensor whose rows are ``(head, relation, tail)`` integer ids.
    family_map
        Maps relation id → family id (indexed by relation id).

    Returns
    -------
    float
        The conditional entropy; ``0.0`` when there are no triples.
    """
    # Convert the family map to plain ints once: indexing the tensor for
    # every triple inside the loop is needlessly slow.
    raw_families = family_map.tolist() if hasattr(family_map, "tolist") else family_map
    fam = [int(f) for f in raw_families]

    m = int(triples.size(0))
    counts: Dict[Tuple[int, int], Dict[int, int]] = defaultdict(lambda: defaultdict(int))

    for h, r, t in triples.tolist():
        c_h, c_t = colours[h], colours[t]
        # Unordered colour pair → direction-agnostic key.
        key = (c_h, c_t) if c_h <= c_t else (c_t, c_h)
        counts[key][fam[r]] += 1

    h_cond = 0.0
    for rel_counts in counts.values():
        total_s = sum(rel_counts.values())
        P_s = total_s / m  # weight of this colour-pair bucket
        inv_total = 1.0 / total_s
        for cnt in rel_counts.values():
            p_rs = cnt * inv_total  # P(family | colour pair)
            h_cond -= P_s * p_rs * math.log(p_rs)
    return h_cond
-
-
class WLCREC(BaseMetric):
    """Weisfeiler-Lehman entropy complexity with residual relation entropy.

    ``compute`` runs Weisfeiler-Lehman (WL) colour refinement on the
    dataset's entities, records the entropy of every colouring, normalises
    those entropies and multiplies them by the entropy of relation families
    conditioned on endpoint colours.  ``wl_colours`` is a standalone WL
    routine that returns the raw colourings.
    """

    def __init__(self, dataset: KGDataset) -> None:
        super().__init__(dataset)

    def compute(
        self,
        H: int = 3,
        cond_h: int = 1,
        inv_thresh: float = 0.9,
        return_full: bool = False,
        progress: bool = True,
    ) -> (
        Tuple[float, float, float, float, float, float]
        | Tuple[float, float, float, float, float, float, List[float]]
    ):
        """Compute WL-entropy scores and the composite difficulty metric.

        Parameters
        ----------
        H
            WL refinement depth (*≥0*). Entropy is recorded for **H+1** layers.
        cond_h
            Index of the recorded WL layer whose colouring conditions the
            relation entropy.  Must satisfy ``-(H+1) <= cond_h <= H``;
            negative values count from the last layer.
        inv_thresh
            Threshold forwarded to ``_relation_families`` as ``thresh``.
        return_full
            Also return the list of per-layer entropies.
        progress
            Print a ``WL iteration i/H`` status line to stdout.

        Returns
        -------
        avg_entropy, C_ratio, C_NWLEC, H_cond, D_ratio, D_NWLEC
            *Six* scalars (floats). If ``return_full=True`` an extra list of
            layer entropies is appended.

        Raises
        ------
        ValueError
            If ``H`` is negative.
        IndexError
            If ``cond_h`` does not address one of the ``H+1`` layers.
        """
        # Validate before doing any expensive work.
        if H < 0:
            raise ValueError(f"H must be non-negative, got {H}")
        if not -(H + 1) <= cond_h <= H:
            raise IndexError(
                f"cond_h={cond_h} is out of range for {H + 1} recorded WL layers"
            )

        # compute relation family map once
        family_map = _relation_families(
            self.dataset.triples, self.dataset.num_relations, thresh=inv_thresh
        )

        # ─── WL iterations ────────────────────────────────────────────────────
        # NOTE(review): ``self.dataset``, ``self.out_adj`` and ``self.in_adj``
        # are not defined in this class; presumably BaseMetric provides them,
        # with adjacency entries of the form (relation, neighbour) - confirm.
        n = self.dataset.num_entities
        colour: List[int] = [1] * n  # colour 1 for everyone
        entropies: List[float] = []
        colours_per_h: List[List[int]] = []

        for depth in range(H + 1):
            if progress:
                print(f"\rWL iteration {depth}/{H}", end="", flush=True)

            # Each layer's list is never mutated afterwards (refinement
            # builds a fresh list), so it can be stored without copying.
            colours_per_h.append(colour)

            # entropy of current colouring
            entropies.append(_entropy_from_counter(Counter(colour), n))

            if depth == H:
                break

            # refine colours: the signature of a node is its own colour plus
            # the sorted multiset of (relation, direction, neighbour colour).
            bucket: Dict[Tuple[int, Tuple[Tuple[int, int, int], ...]], int] = {}
            next_colour: List[int] = [0] * n

            for v in range(n):
                signature: List[Tuple[int, int, int]] = [
                    (r, 0, colour[t]) for r, t in self.out_adj[v]  # outgoing
                ]
                signature.extend(
                    (r, 1, colour[s]) for r, s in self.in_adj[v]  # incoming
                )
                signature.sort()
                key = (colour[v], tuple(signature))
                # New signatures get the next free colour id, starting at 1.
                next_colour[v] = bucket.setdefault(key, len(bucket) + 1)

            colour = next_colour

        if progress:
            print()  # terminate the "\r" status line

        # ─── Normalised diversity ────────────────────────────────────────────
        avg_entropy = sum(entropies) / len(entropies)
        c_ratio, c_nwlec = _normalise_scores(entropies, n)

        # ─── Residual relation entropy ───────────────────────────────────────
        h_cond = _conditional_relation_entropy(
            colours_per_h[cond_h], self.dataset.triples, family_map
        )

        # ─── Composite difficulty - Eq. (1) ──────────────────────────────────
        d_ratio = c_ratio * h_cond
        d_nwlec = c_nwlec * h_cond

        result = (avg_entropy, c_ratio, c_nwlec, h_cond, d_ratio, d_nwlec)
        if return_full:
            return result + (entropies,)
        return result

    # --------------------------------------------------------------------
    # Weisfeiler–Lehman colours (unchanged)
    # --------------------------------------------------------------------
    def wl_colours(self, depth: int) -> List[List[int]]:
        """Return the WL colourings for rounds ``0 .. depth``.

        Adjacency is built here from ``self.dataset.triples``.  Round 0
        assigns colour ``0`` to every entity; each later round gives colour
        ids ``0, 1, 2, ...`` in order of first appearance of a node's
        signature (own colour plus sorted neighbour descriptors).

        Returns a list of ``depth + 1`` colour lists (only round 0 when
        ``depth <= 0``), each indexed by entity id.
        """
        triples = self.dataset.triples.tolist()
        out_adj, in_adj = defaultdict(list), defaultdict(list)
        for h, r, t in triples:
            out_adj[h].append((r, t))
            in_adj[t].append((r, h))

        n_ent = self.dataset.num_entities
        colours = [[0] * n_ent]  # round-0
        for _ in range(1, depth + 1):
            prev, nxt, sig2c = colours[-1], [0] * n_ent, {}
            fresh = 0
            for v in range(n_ent):
                # ``.get`` avoids inserting empty entries into the defaultdicts.
                neigh = [("↓", r, prev[u]) for r, u in out_adj.get(v, [])] + [
                    ("↑", r, prev[u]) for r, u in in_adj.get(v, [])
                ]
                neigh.sort()
                sig = (prev[v], tuple(neigh))
                if sig not in sig2c:
                    sig2c[sig] = fresh
                    fresh += 1
                nxt[v] = sig2c[sig]
            colours.append(nxt)
        return colours
|