from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch
from torch import nn
from torch.nn.functional import log_softmax, softmax

if TYPE_CHECKING:
    from pykeen.typing import FloatTensor, InductiveMode, LongTensor, MappedTriples, Target


class ERModel(nn.Module, ABC):
    """Base class for knowledge graph models.

    All scoring is delegated to ``inner_model``, which a concrete subclass
    must assign before any of the prediction methods are called.
    """

    def __init__(self, num_entities: int, num_relations: int, dim: int) -> None:
        super().__init__()
        self.num_entities = num_entities
        self.num_relations = num_relations
        self.dim = dim

        # Assigned by subclasses; every scoring method delegates to it.
        self.inner_model: nn.Module | None = None

    @property
    def device(self) -> torch.device:
        """Return the device of the model."""
        if self.inner_model is None:
            msg = "Inner model is not set, so its device cannot be determined."
            raise ValueError(msg)
        return next(self.inner_model.parameters()).device

    @abstractmethod
    def reset_parameters(self, *args, **kwargs) -> None:
        """Reset the parameters of the model."""

    def predict(
        self,
        hrt_batch: MappedTriples,
        target: Target,
        full_batch: bool = True,
        ids: LongTensor | None = None,
        **kwargs,
    ) -> FloatTensor:
        """Predict scores for the given target side by delegating to the inner model."""
        if self.inner_model is None:
            msg = (
                "Inner model is not set. "
                "Please initialize the inner model before calling predict."
            )
            raise ValueError(msg)

        return self.inner_model.predict(
            hrt_batch=hrt_batch,
            target=target,
            full_batch=full_batch,
            ids=ids,
            **kwargs,
        )

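    # Hypothetical call pattern (sketch only; ``model`` is a concrete subclass
    # whose ``inner_model`` has been initialised, and ``"tail"`` selects the
    # tail side as the prediction target):
    #
    #     scores = model.predict(hrt_batch=batch, target="tail")  # (B, |E|)
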
    def forward(
        self,
        triples: MappedTriples,
        slice_size: int | None = None,
        slice_dim: int = 0,
        *,
        mode: InductiveMode | None,
    ) -> FloatTensor:
        """Score a batch of (head, relation, tail) triples by delegating to the inner model."""
        if self.inner_model is None:
            msg = (
                "Inner model is not set. "
                "Please initialize the inner model before calling forward."
            )
            raise ValueError(msg)

        h_indices = triples[:, 0]
        r_indices = triples[:, 1]
        t_indices = triples[:, 2]

        return self.inner_model.forward(
            h_indices=h_indices,
            r_indices=r_indices,
            t_indices=t_indices,
            slice_size=slice_size,
            slice_dim=slice_dim,
            mode=mode,
        )

    # ----------------------------------------------------------------------- #
    # (0) TAIL-given-(head, relation) → pθ(t | h,r)                            #
    # ----------------------------------------------------------------------- #
    @torch.inference_mode()
    def tail_distribution(
        self,
        hr_batch: LongTensor,  # (B, 2) : (h,r)
        *,
        slice_size: int | None = None,
        mode: InductiveMode | None = None,
        log: bool = False,
    ) -> FloatTensor:  # (B, |E|)
        """Return pθ(t | h, r) over all entities (log-probabilities if ``log``).

        See the runnable sketch at the end of this module.
        """
        scores = self.inner_model.score_t(hr_batch=hr_batch, slice_size=slice_size, mode=mode)
        return log_softmax(scores, dim=-1) if log else softmax(scores, dim=-1)

    # ----------------------------------------------------------------------- #
    # (1) HEAD-given-(relation, tail) → pθ(h | r,t)                            #
    # ----------------------------------------------------------------------- #
    @torch.inference_mode()
    def head_distribution(
        self,
        rt_batch: LongTensor,  # (B, 2) : (r,t)
        *,
        slice_size: int | None = None,
        mode: InductiveMode | None = None,
        log: bool = False,
    ) -> FloatTensor:  # (B, |E|)
        """Return pθ(h | r, t) over all entities (log-probabilities if ``log``)."""
        scores = self.inner_model.score_h(rt_batch=rt_batch, slice_size=slice_size, mode=mode)
        return log_softmax(scores, dim=-1) if log else softmax(scores, dim=-1)

    # ----------------------------------------------------------------------- #
    # (2) RELATION-given-(head, tail) → pθ(r | h,t)                            #
    # ----------------------------------------------------------------------- #
    @torch.inference_mode()
    def relation_distribution(
        self,
        ht_batch: LongTensor,  # (B, 2) : (h,t)
        *,
        slice_size: int | None = None,
        mode: InductiveMode | None = None,
        log: bool = False,
    ) -> FloatTensor:  # (B, |R|)
        """Return pθ(r | h, t) over all relations (log-probabilities if ``log``)."""
        scores = self.inner_model.score_r(ht_batch=ht_batch, slice_size=slice_size, mode=mode)
        return log_softmax(scores, dim=-1) if log else softmax(scores, dim=-1)
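

# --------------------------------------------------------------------------- #
# Usage sketch (hypothetical): ``_ToyInnerModel`` and ``_ToyModel`` below are  #
# illustrative stand-ins, not part of pykeen. They only show how a concrete    #
# subclass wires up ``inner_model`` and queries a conditional distribution.    #
# --------------------------------------------------------------------------- #
if __name__ == "__main__":

    class _ToyInnerModel(nn.Module):
        """Minimal DistMult-style scorer over random embeddings."""

        def __init__(self, num_entities: int, num_relations: int, dim: int) -> None:
            super().__init__()
            self.entity_emb = nn.Embedding(num_entities, dim)
            self.relation_emb = nn.Embedding(num_relations, dim)

        def score_t(self, hr_batch, slice_size=None, mode=None):
            # Score every entity as a tail candidate: (B, |E|).
            h = self.entity_emb(hr_batch[:, 0])
            r = self.relation_emb(hr_batch[:, 1])
            return (h * r) @ self.entity_emb.weight.t()

    class _ToyModel(ERModel):
        """Concrete subclass that assigns its inner model on construction."""

        def __init__(self, num_entities: int, num_relations: int, dim: int) -> None:
            super().__init__(num_entities, num_relations, dim)
            self.inner_model = _ToyInnerModel(num_entities, num_relations, dim)

        def reset_parameters(self) -> None:
            nn.init.normal_(self.inner_model.entity_emb.weight)
            nn.init.normal_(self.inner_model.relation_emb.weight)

    model = _ToyModel(num_entities=5, num_relations=2, dim=8)
    hr_batch = torch.tensor([[0, 1], [3, 0]])  # two (h, r) queries
    probs = model.tail_distribution(hr_batch)
    # Each row is pθ(t | h, r) over all five entities, so it sums to 1.
    assert torch.allclose(probs.sum(dim=-1), torch.ones(2))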