
trans_h.py 2.5KB

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from pykeen.losses import MarginRankingLoss
from pykeen.models import TransH as PyKEENTransH
from torch import nn

from models.base_model import ERModel

if TYPE_CHECKING:
    from pykeen.triples import CoreTriplesFactory


class TransH(ERModel):
    """TransH model for knowledge graph embedding, wrapping the PyKEEN implementation."""
    def __init__(
        self,
        num_entities: int,
        num_relations: int,
        triples_factory: CoreTriplesFactory,
        dim: int = 200,
        p_norm: bool = True,
        margin: float | None = None,
        epsilon: float | None = None,
        device: torch.device = torch.device("cpu"),
        **kwargs,
    ) -> None:
        super().__init__(num_entities, num_relations, dim)
        self.dim = dim
        self.p_norm = p_norm
        self.margin = margin
        self.epsilon = epsilon
        # Delegate embedding and scoring to PyKEEN's TransH implementation.
        self.inner_model = PyKEENTransH(
            triples_factory=triples_factory,
            embedding_dim=dim,
            scoring_fct_norm=1,  # L1 norm in the scoring function
            loss=self.loss,
            power_norm=p_norm,
            entity_initializer="xavier_uniform_",  # good default
            relation_initializer="xavier_uniform_",
            random_seed=42,
        ).to(device)

    @property
    def loss(self) -> MarginRankingLoss:
        # MarginRankingLoss requires a float margin; fall back to PyKEEN's default of 1.0.
        margin = self.margin if self.margin is not None else 1.0
        return MarginRankingLoss(margin=margin, reduction="mean")

    def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None:
        """Re-initialize the entity and relation embedding weights."""
        if margin is None or epsilon is None:
            # If no margin/epsilon are provided, use Xavier uniform initialization.
            nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
            nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
        else:
            # Otherwise, initialize uniformly in [-(margin + epsilon) / dim, +(margin + epsilon) / dim].
            self.embedding_range = nn.Parameter(
                torch.tensor([(margin + epsilon) / self.dim]),
                requires_grad=False,
            )
            nn.init.uniform_(
                tensor=self.ent_embeddings.weight.data,
                a=-self.embedding_range.item(),
                b=self.embedding_range.item(),
            )
            nn.init.uniform_(
                tensor=self.rel_embeddings.weight.data,
                a=-self.embedding_range.item(),
                b=self.embedding_range.item(),
            )
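
For reference, a minimal usage sketch follows. It assumes a standard PyKEEN install (using the bundled Nations toy dataset) and that the repo's ERModel base class exposes ent_embeddings/rel_embeddings as nn.Embedding layers, as reset_parameters implies; names like dataset and hrt are illustrative, not part of this file.

    from pykeen.datasets import Nations

    dataset = Nations()
    tf = dataset.training  # a TriplesFactory carrying num_entities / num_relations

    model = TransH(
        num_entities=tf.num_entities,
        num_relations=tf.num_relations,
        triples_factory=tf,
        dim=200,
        margin=4.0,
        epsilon=2.0,
    )

    # With margin=4.0, epsilon=2.0, dim=200, reset_parameters draws weights
    # uniformly from [-(4.0 + 2.0) / 200, +(4.0 + 2.0) / 200] = [-0.03, 0.03].
    model.reset_parameters(margin=4.0, epsilon=2.0)

    # Score a batch of (head, relation, tail) index triples via the wrapped model;
    # in PyKEEN, higher scores mean more plausible triples.
    hrt = tf.mapped_triples[:8]
    scores = model.inner_model.score_hrt(hrt)  # shape: (8, 1)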