You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

base_model.py 4.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch
from torch import nn
from torch.nn.functional import log_softmax, softmax

if TYPE_CHECKING:
    from pykeen.typing import FloatTensor, InductiveMode, LongTensor, MappedTriples, Target
  9. class ERModel(nn.Module, ABC):
  10. """Base class for knowledge graph models."""
  11. def __init__(self, num_entities: int, num_relations: int, dim: int) -> None:
  12. super().__init__()
  13. self.num_entities = num_entities
  14. self.num_relations = num_relations
  15. self.dim = dim
  16. self.inner_model = None
  17. @property
  18. def device(self) -> torch.device:
  19. """Return the device of the model."""
  20. return next(self.inner_model.parameters()).device
  21. @abstractmethod
  22. def reset_parameters(self, *args, **kwargs) -> None:
  23. """Reset the parameters of the model."""
  24. def predict(
  25. self,
  26. hrt_batch: MappedTriples,
  27. target: Target,
  28. full_batch: bool = True,
  29. ids: LongTensor | None = None,
  30. **kwargs,
  31. ) -> FloatTensor:
  32. if not self.inner_model:
  33. msg = (
  34. "Inner model is not set. Please initialize the inner model before calling predict."
  35. )
  36. raise ValueError(
  37. msg,
  38. )
  39. return self.inner_model.predict(
  40. hrt_batch=hrt_batch,
  41. target=target,
  42. full_batch=full_batch,
  43. ids=ids,
  44. **kwargs,
  45. )
  46. def forward(
  47. self,
  48. triples: MappedTriples,
  49. slice_size: int | None = None,
  50. slice_dim: int = 0,
  51. *,
  52. mode: InductiveMode | None,
  53. ) -> FloatTensor:
  54. h_indices = triples[:, 0]
  55. r_indices = triples[:, 1]
  56. t_indices = triples[:, 2]
  57. if not self.inner_model:
  58. msg = (
  59. "Inner model is not set. Please initialize the inner model before calling forward."
  60. )
  61. raise ValueError(
  62. msg,
  63. )
  64. return self.inner_model.forward(
  65. h_indices=h_indices,
  66. r_indices=r_indices,
  67. t_indices=t_indices,
  68. slice_size=slice_size,
  69. slice_dim=slice_dim,
  70. mode=mode,
  71. )
  72. @torch.inference_mode()
  73. def tail_distribution( # (0) TAIL-given-(head, relation) → pθ(t | h,r)
  74. self,
  75. hr_batch: LongTensor, # (B, 2) : (h,r)
  76. *,
  77. slice_size: int | None = None,
  78. mode: InductiveMode | None = None,
  79. log: bool = False,
  80. ) -> FloatTensor: # (B, |E|)
  81. scores = self.inner_model.score_t(hr_batch=hr_batch, slice_size=slice_size, mode=mode)
  82. return log_softmax(scores, -1) if log else softmax(scores, -1)
  83. # ----------------------------------------------------------------------- #
  84. # (1) HEAD-given-(relation, tail) → pθ(h | r,t) #
  85. # ----------------------------------------------------------------------- #
  86. @torch.inference_mode()
  87. def head_distribution(
  88. self,
  89. rt_batch: LongTensor, # (B, 2) : (r,t)
  90. *,
  91. slice_size: int | None = None,
  92. mode: InductiveMode | None = None,
  93. log: bool = False,
  94. ) -> FloatTensor: # (B, |E|)
  95. scores = self.inner_model.score_h(rt_batch=rt_batch, slice_size=slice_size, mode=mode)
  96. return log_softmax(scores, -1) if log else softmax(scores, -1)
  97. # ----------------------------------------------------------------------- #
  98. # (2) RELATION-given-(head, tail) → pθ(r | h,t) #
  99. # ----------------------------------------------------------------------- #
  100. @torch.inference_mode()
  101. def relation_distribution(
  102. self,
  103. ht_batch: LongTensor, # (B, 2) : (h,t)
  104. *,
  105. slice_size: int | None = None,
  106. mode: InductiveMode | None = None,
  107. log: bool = False,
  108. ) -> FloatTensor: # (B, |R|)
  109. scores = self.inner_model.score_r(ht_batch=ht_batch, slice_size=slice_size, mode=mode)
  110. return log_softmax(scores, -1) if log else softmax(scores, -1)