You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

trans_e.py 2.2KB

Lines 1–70
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from pykeen.models import TransE as PyKEENTransE
from torch import nn

from models.base_model import ERModel

if TYPE_CHECKING:
    from pykeen.triples import CoreTriplesFactory
    from pykeen.typing import MappedTriples
  9. class TransE(ERModel):
  10. """
  11. TransE model for knowledge graph embedding.
  12. score = ||h + r - t||
  13. """
  14. def __init__(
  15. self,
  16. num_entities: int,
  17. num_relations: int,
  18. triples_factory: MappedTriples,
  19. dim: int = 200,
  20. sf_norm: int = 1,
  21. p_norm: bool = True,
  22. margin: float | None = None,
  23. epsilon: float | None = None,
  24. device: torch.device = torch.device("cpu"),
  25. ) -> None:
  26. super().__init__(num_entities, num_relations, dim)
  27. self.dim = dim
  28. self.sf_norm = sf_norm
  29. self.p_norm = p_norm
  30. self.margin = margin
  31. self.epsilon = epsilon
  32. self.inner_model = PyKEENTransE(
  33. triples_factory=triples_factory,
  34. embedding_dim=dim,
  35. scoring_fct_norm=sf_norm,
  36. power_norm=p_norm,
  37. random_seed=42,
  38. ).to(device)
  39. def reset_parameters(self, margin: float | None = None, epsilon: float | None = None) -> None:
  40. # Parameter initialization
  41. if margin is None or epsilon is None:
  42. # If no margin/epsilon are provided, use Xavier uniform initialization
  43. nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
  44. nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
  45. else:
  46. # Otherwise, initialize uniformly in [ -(margin+epsilon)/dim , +(margin+epsilon)/dim ]
  47. self.embedding_range = nn.Parameter(
  48. torch.Tensor([(margin + epsilon) / self.dim]),
  49. requires_grad=False,
  50. )
  51. nn.init.uniform_(
  52. tensor=self.ent_embeddings.weight.data,
  53. a=-self.embedding_range.item(),
  54. b=self.embedding_range.item(),
  55. )
  56. nn.init.uniform_(
  57. tensor=self.rel_embeddings.weight.data,
  58. a=-self.embedding_range.item(),
  59. b=self.embedding_range.item(),
  60. )