
trans_r.py

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from pykeen.losses import MarginRankingLoss
from pykeen.models import TransR as PyKEENTransR
from torch import nn

from models.base_model import ERModel

if TYPE_CHECKING:
    from pykeen.triples import CoreTriplesFactory
class TransR(ERModel):
    """
    TransR model for knowledge graph embedding (Lin et al., 2015).

    Entities and relations are embedded in separate spaces; each relation
    has a projection matrix that maps entity embeddings into its relation
    space, where the translation h + r ≈ t is scored. This class wraps
    PyKEEN's TransR implementation.
    """
    def __init__(
        self,
        num_entities: int,
        num_relations: int,
        triples_factory: CoreTriplesFactory,
        entity_dim: int = 200,
        relation_dim: int = 30,
        p_norm: bool = True,
        p_norm_value: int = 2,
        margin: float | None = None,
        epsilon: float | None = None,
        device: torch.device = torch.device("cpu"),
        **kwargs,
    ) -> None:
        super().__init__(num_entities, num_relations, entity_dim)
        self.entity_dim = entity_dim
        self.relation_dim = relation_dim
        self.p_norm = p_norm
        self.margin = margin
        self.epsilon = epsilon
        # Delegate scoring to PyKEEN's TransR, which holds the trainable
        # embeddings and relation projection matrices.
        self.inner_model = PyKEENTransR(
            triples_factory=triples_factory,
            embedding_dim=entity_dim,
            relation_dim=relation_dim,
            scoring_fct_norm=p_norm_value,
            loss=self.loss,
            power_norm=p_norm,
            entity_initializer="xavier_uniform_",
            relation_initializer="xavier_uniform_",
            random_seed=42,
        ).to(device)
    @property
    def loss(self) -> MarginRankingLoss:
        # Fall back to a margin of 1.0 when none was provided.
        margin = self.margin if self.margin is not None else 1.0
        return MarginRankingLoss(margin=margin, reduction="mean")
    def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None:
        # Re-initialize the embedding tables provided by the ERModel base class.
        if margin is None or epsilon is None:
            # If no margin/epsilon pair is provided, use Xavier uniform initialization.
            nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
            nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
        else:
            # Otherwise, initialize uniformly in [-(margin + epsilon) / dim, +(margin + epsilon) / dim];
            # e.g. margin=9.0 and epsilon=2.0 with entity_dim=200 gives a range of ±0.055.
            self.embedding_range = nn.Parameter(
                torch.Tensor([(margin + epsilon) / self.entity_dim]),
                requires_grad=False,
            )
            nn.init.uniform_(
                tensor=self.ent_embeddings.weight.data,
                a=-self.embedding_range.item(),
                b=self.embedding_range.item(),
            )
            nn.init.uniform_(
                tensor=self.rel_embeddings.weight.data,
                a=-self.embedding_range.item(),
                b=self.embedding_range.item(),
            )
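
Below is a minimal usage sketch. It assumes a PyKEEN triples factory is available; the Nations toy dataset and the margin value are purely illustrative and not part of this repository.

from pykeen.datasets import Nations

# Toy dataset purely for illustration; any CoreTriplesFactory works here.
tf = Nations().training

model = TransR(
    num_entities=tf.num_entities,
    num_relations=tf.num_relations,
    triples_factory=tf,
    entity_dim=200,
    relation_dim=30,
    margin=9.0,
)

# Score a small batch of (head, relation, tail) index triples through the
# wrapped PyKEEN model; higher scores indicate more plausible triples.
scores = model.inner_model.score_hrt(tf.mapped_triples[:4])
print(scores.shape)  # torch.Size([4, 1])

Training would normally go through whatever loop the surrounding repository defines for ERModel subclasses, or through PyKEEN's own training pipeline against model.inner_model.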