from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch
from torch import nn
from torch.nn.functional import log_softmax, softmax

if TYPE_CHECKING:
    from pykeen.typing import FloatTensor, InductiveMode, LongTensor, MappedTriples, Target


class ERModel(nn.Module, ABC):
    """Base class for knowledge graph models.

    The class is a thin wrapper: every scoring/prediction call is delegated to
    ``self.inner_model``, which subclasses are expected to assign after
    construction (it starts out as ``None``). Until it is assigned, every
    delegating member raises :class:`ValueError`.
    """

    def __init__(self, num_entities: int, num_relations: int, dim: int) -> None:
        """Store the basic sizes and leave the inner model unset.

        :param num_entities: stored as ``self.num_entities``.
        :param num_relations: stored as ``self.num_relations``.
        :param dim: stored as ``self.dim``.
        """
        super().__init__()
        self.num_entities = num_entities
        self.num_relations = num_relations
        self.dim = dim
        # Placeholder; a subclass (or the caller) assigns the wrapped model later.
        self.inner_model: nn.Module | None = None

    def _require_inner_model(self, action: str) -> nn.Module:
        """Return the inner model, or raise if it has not been assigned yet.

        :param action: what the caller was trying to do (e.g. ``"calling predict"``);
            interpolated into the error message.
        :raises ValueError: if ``self.inner_model`` is ``None``.
        """
        inner = self.inner_model
        # Compare against None explicitly: modules defining ``__len__`` (e.g. an
        # empty ``nn.Sequential``) are falsy although they are a valid value.
        if inner is None:
            msg = f"Inner model is not set. Please initialize the inner model before {action}."
            raise ValueError(
                msg,
            )
        return inner

    @staticmethod
    def _normalize(scores: FloatTensor, log: bool) -> FloatTensor:
        """Apply (log-)softmax over the last dimension of ``scores``."""
        return log_softmax(scores, -1) if log else softmax(scores, -1)

    @property
    def device(self) -> torch.device:
        """Return the device of the model.

        The device is taken from the first parameter of the inner model, then
        from its first buffer; if the inner model has neither, CPU is returned.

        :raises ValueError: if the inner model is not set.
        """
        inner = self._require_inner_model("accessing device")
        tensor = next(inner.parameters(), None)
        if tensor is None:
            tensor = next(inner.buffers(), None)
        # A module without any tensors has no device of its own; fall back to CPU
        # instead of leaking a bare StopIteration out of the property.
        return torch.device("cpu") if tensor is None else tensor.device

    @abstractmethod
    def reset_parameters(self, *args, **kwargs) -> None:
        """Reset the parameters of the model."""

    def predict(
        self,
        hrt_batch: MappedTriples,
        target: Target,
        full_batch: bool = True,
        ids: LongTensor | None = None,
        **kwargs,
    ) -> FloatTensor:
        """Delegate prediction to ``inner_model.predict``.

        All arguments, including any extra keyword arguments, are forwarded
        unchanged by keyword.

        :raises ValueError: if the inner model is not set.
        """
        inner = self._require_inner_model("calling predict")
        return inner.predict(
            hrt_batch=hrt_batch,
            target=target,
            full_batch=full_batch,
            ids=ids,
            **kwargs,
        )

    def forward(
        self,
        triples: MappedTriples,
        slice_size: int | None = None,
        slice_dim: int = 0,
        *,
        mode: InductiveMode | None,
    ) -> FloatTensor:
        """Split ``triples`` into its columns and call ``inner_model.forward``.

        :param triples: 2-D index tensor; columns 0, 1 and 2 are passed on as
            ``h_indices``, ``r_indices`` and ``t_indices`` respectively.
        :param slice_size: forwarded unchanged.
        :param slice_dim: forwarded unchanged.
        :param mode: forwarded unchanged; keyword-only and has no default.
        :raises ValueError: if the inner model is not set.
        """
        # Check first so a missing inner model is reported before any indexing.
        inner = self._require_inner_model("calling forward")
        return inner.forward(
            h_indices=triples[:, 0],
            r_indices=triples[:, 1],
            t_indices=triples[:, 2],
            slice_size=slice_size,
            slice_dim=slice_dim,
            mode=mode,
        )

    # ----------------------------------------------------------------------- #
    # (0) TAIL-given-(head, relation)  →  pθ(t | h,r)                         #
    # ----------------------------------------------------------------------- #
    @torch.inference_mode()
    def tail_distribution(
        self,
        hr_batch: LongTensor,  # (B, 2) : (h,r)
        *,
        slice_size: int | None = None,
        mode: InductiveMode | None = None,
        log: bool = False,
    ) -> FloatTensor:  # (B, |E|)
        """Normalize ``inner_model.score_t`` scores into a distribution over tails.

        :param hr_batch: (head, relation) index pairs.
        :param slice_size: forwarded to ``score_t``.
        :param mode: forwarded to ``score_t``.
        :param log: if true, return log-probabilities (``log_softmax``) instead
            of probabilities (``softmax``); both act on the last dimension.
        :raises ValueError: if the inner model is not set.
        """
        inner = self._require_inner_model("calling tail_distribution")
        scores = inner.score_t(hr_batch=hr_batch, slice_size=slice_size, mode=mode)
        return self._normalize(scores, log)

    # ----------------------------------------------------------------------- #
    # (1) HEAD-given-(relation, tail)  →  pθ(h | r,t)                         #
    # ----------------------------------------------------------------------- #
    @torch.inference_mode()
    def head_distribution(
        self,
        rt_batch: LongTensor,  # (B, 2) : (r,t)
        *,
        slice_size: int | None = None,
        mode: InductiveMode | None = None,
        log: bool = False,
    ) -> FloatTensor:  # (B, |E|)
        """Normalize ``inner_model.score_h`` scores into a distribution over heads.

        :param rt_batch: (relation, tail) index pairs.
        :param slice_size: forwarded to ``score_h``.
        :param mode: forwarded to ``score_h``.
        :param log: if true, return log-probabilities (``log_softmax``) instead
            of probabilities (``softmax``); both act on the last dimension.
        :raises ValueError: if the inner model is not set.
        """
        inner = self._require_inner_model("calling head_distribution")
        scores = inner.score_h(rt_batch=rt_batch, slice_size=slice_size, mode=mode)
        return self._normalize(scores, log)

    # ----------------------------------------------------------------------- #
    # (2) RELATION-given-(head, tail)  →  pθ(r | h,t)                         #
    # ----------------------------------------------------------------------- #
    @torch.inference_mode()
    def relation_distribution(
        self,
        ht_batch: LongTensor,  # (B, 2) : (h,t)
        *,
        slice_size: int | None = None,
        mode: InductiveMode | None = None,
        log: bool = False,
    ) -> FloatTensor:  # (B, |R|)
        """Normalize ``inner_model.score_r`` scores into a distribution over relations.

        :param ht_batch: (head, tail) index pairs.
        :param slice_size: forwarded to ``score_r``.
        :param mode: forwarded to ``score_r``.
        :param log: if true, return log-probabilities (``log_softmax``) instead
            of probabilities (``softmax``); both act on the last dimension.
        :raises ValueError: if the inner model is not set.
        """
        inner = self._require_inner_model("calling relation_distribution")
        scores = inner.score_r(ht_batch=ht_batch, slice_size=slice_size, mode=mode)
        return self._normalize(scores, log)