You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

base_metric.py 1.6KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Tuple

if TYPE_CHECKING:
    from data.kg_dataset import KGDataset
  6. class BaseMetric(ABC):
  7. """Base class for metrics that compute Weisfeiler-Lehman Entropy
  8. Complexity (WLEC) or similar metrics.
  9. """
  10. def __init__(self, dataset: KGDataset) -> None:
  11. self.dataset = dataset
  12. self.num_entities = dataset.num_entities
  13. self.out_adj, self.in_adj = self._build_adjacency_lists()
  14. def _build_adjacency_lists(
  15. self,
  16. ) -> Tuple[List[List[Tuple[int, int]]], List[List[Tuple[int, int]]]]:
  17. """Create *in* and *out* adjacency lists from mapped triples.
  18. Parameters
  19. ----------
  20. triples:
  21. Tensor of shape *(m, 3)* containing *(h, r, t)* triples (``dtype=torch.long``).
  22. num_entities:
  23. Total number of entities. Needed to allocate adjacency containers.
  24. Returns
  25. -------
  26. out_adj, in_adj
  27. Each element ``out_adj[v]`` is a list ``[(r, t), ...]``. Each element
  28. ``in_adj[v]`` is a list ``[(r, h), ...]``.
  29. """
  30. triples = self.dataset.triples # (m, 3)
  31. num_entities = self.num_entities
  32. out_adj: List[List[Tuple[int, int]]] = [[] for _ in range(num_entities)]
  33. in_adj: List[List[Tuple[int, int]]] = [[] for _ in range(num_entities)]
  34. for h, r, t in triples.tolist():
  35. out_adj[h].append((r, t))
  36. in_adj[t].append((r, h))
  37. return out_adj, in_adj
  38. @abstractmethod
  39. def compute(
  40. self,
  41. *args,
  42. **kwargs,
  43. ) -> Any: ...