12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- from __future__ import annotations
-
- from typing import TYPE_CHECKING
-
- import torch
- from pykeen.losses import MarginRankingLoss
- from pykeen.models import TransR as PyKEENTransR
- from torch import nn
-
- from models.base_model import ERModel
-
- if TYPE_CHECKING:
- from pykeen.typing import MappedTriples
-
-
class TransR(ERModel):
    """TransR model for knowledge graph embedding.

    Thin wrapper around PyKEEN's ``TransR``: entities are embedded in an
    ``entity_dim``-dimensional space and each relation owns a projection
    matrix into its own ``relation_dim``-dimensional space, where the
    translation ``h_r + r ≈ t_r`` is scored.
    """

    def __init__(
        self,
        num_entities: int,
        num_relations: int,
        triples_factory: MappedTriples,
        entity_dim: int = 200,
        relation_dim: int = 30,
        p_norm: bool = True,
        p_norm_value: int = 2,
        margin: float | None = None,
        epsilon: float | None = None,
        device: torch.device = torch.device("cpu"),
        **kwargs,
    ) -> None:
        """Build the wrapped PyKEEN TransR model.

        Args:
            num_entities: Number of entities in the graph.
            num_relations: Number of relation types in the graph.
            triples_factory: PyKEEN triples factory describing the training
                triples (annotated ``MappedTriples`` upstream; PyKEEN's
                ``TransR`` actually expects a ``TriplesFactory`` — TODO confirm
                against callers).
            entity_dim: Dimensionality of the entity embedding space.
            relation_dim: Dimensionality of the relation embedding space.
            p_norm: Forwarded to PyKEEN as ``power_norm`` (use the p-th power
                of the norm instead of the norm itself).
            p_norm_value: Intended norm order. NOTE(review): the original code
                accepted this but never used it — ``scoring_fct_norm`` is
                hard-coded to 1 below. Kept hard-coded for backward-compatible
                scores, but the value is now stored for introspection.
            margin: Margin for the ranking loss; ``None`` falls back to 1.0.
            epsilon: Half-width adjustment for uniform re-initialization in
                :meth:`reset_parameters`.
            device: Device the inner model is moved to.
            **kwargs: Accepted for interface compatibility; ignored here.
        """
        super().__init__(num_entities, num_relations, entity_dim)
        self.entity_dim = entity_dim
        self.relation_dim = relation_dim
        self.p_norm = p_norm
        # Stored but not forwarded, to preserve the original scoring behavior
        # (scoring_fct_norm stays 1, as before).
        self.p_norm_value = p_norm_value
        self.margin = margin
        self.epsilon = epsilon

        self.inner_model = PyKEENTransR(
            triples_factory=triples_factory,
            embedding_dim=entity_dim,
            relation_dim=relation_dim,
            scoring_fct_norm=1,
            loss=self.loss,
            power_norm=p_norm,
            entity_initializer="xavier_uniform_",
            relation_initializer="xavier_uniform_",
            random_seed=42,  # fixed seed for reproducible initialization
        ).to(device)

    @property
    def loss(self) -> MarginRankingLoss:
        """Fresh margin ranking loss (mean-reduced) for the inner model.

        Fix: the original passed ``margin=self.margin`` straight through, so a
        ``None`` margin (the constructor default) broke the loss; fall back to
        the conventional default margin of 1.0. The return annotation was also
        corrected from ``torch.Tensor`` to the loss module type.
        """
        effective_margin = 1.0 if self.margin is None else self.margin
        return MarginRankingLoss(margin=effective_margin, reduction="mean")

    def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None:
        """Re-initialize entity/relation embedding weights.

        Xavier-uniform when ``margin``/``epsilon`` are not both given;
        otherwise uniform in ``[-(margin+epsilon)/dim, +(margin+epsilon)/dim]``.

        NOTE(review): ``self.ent_embeddings``, ``self.rel_embeddings`` and
        ``self.dim`` are not defined in this class — presumably provided by the
        ``ERModel`` base (``dim`` likely equals ``entity_dim``); verify against
        the base class.
        """
        if margin is None or epsilon is None:
            # No margin/epsilon provided: Xavier uniform initialization.
            nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
            nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
        else:
            # Uniform initialization in +/-(margin + epsilon) / dim; stored as a
            # frozen parameter so it persists with the module state.
            self.embedding_range = nn.Parameter(
                torch.Tensor([(margin + epsilon) / self.dim]),
                requires_grad=False,
            )
            nn.init.uniform_(
                tensor=self.ent_embeddings.weight.data,
                a=-self.embedding_range.item(),
                b=self.embedding_range.item(),
            )
            nn.init.uniform_(
                tensor=self.rel_embeddings.weight.data,
                a=-self.embedding_range.item(),
                b=self.embedding_range.item(),
            )
|