You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

wn18.py 3.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. from typing import List, Tuple
  2. import torch
  3. from pykeen.datasets import WN18, WN18RR
  4. from .kg_dataset import KGDataset # Importing KGDataset from the same module
  5. class WN18Dataset(KGDataset):
  6. """A wrapper around PyKeen's WN18 dataset that produces KGDataset instances."""
  7. def __init__(self, split: str = "train") -> None:
  8. """
  9. Initialize the WN18RR wrapper.
  10. :param split: One of "train", "valid", or "test" indicating which split to load.
  11. """
  12. # Load the PyKeen WN18 dataset
  13. dataset = WN18()
  14. # Select the appropriate TriplesFactory based on the requested split
  15. if split == "train":
  16. tf = dataset.training
  17. elif split in ("valid", "val", "validation"):
  18. tf = dataset.validation
  19. elif split in ("test", "testing"):
  20. tf = dataset.testing
  21. else:
  22. msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
  23. raise ValueError(
  24. msg,
  25. )
  26. # `tf.mapped_triples` is a NumPy array of shape (n_triples, 3) with dtype=int64
  27. # Convert each row into a Python tuple of ints
  28. triples_list: List[Tuple[int, int, int]] = [
  29. (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples
  30. ]
  31. # Get the total number of entities and relations in the entire dataset
  32. num_entities = dataset.num_entities
  33. num_relations = dataset.num_relations
  34. # Initialize the base KGDataset with the selected triples
  35. super().__init__(
  36. triples_factory=tf,
  37. triples=triples_list,
  38. num_entities=num_entities,
  39. num_relations=num_relations,
  40. split=split,
  41. )
  42. def __getitem__(self, idx: int) -> torch.Tensor:
  43. data = super().__getitem__(idx)
  44. # Convert the tuple to a torch tensor
  45. return torch.tensor(data, dtype=torch.long)
  46. # KGDataset is already imported at the top of this file
  47. # (from .kg_dataset import KGDataset); no further import is needed here.
  48. class WN18RRDataset(KGDataset):
  49. """A wrapper around PyKeen's WN18RR dataset that produces KGDataset instances."""
  50. def __init__(self, split: str = "train") -> None:
  51. """
  52. Initialize the WN18RR wrapper.
  53. :param split: One of "train", "valid", or "test" indicating which split to load.
  54. """
  55. # Load the PyKeen WN18RR dataset
  56. dataset = WN18RR()
  57. # Select the appropriate TriplesFactory based on the requested split
  58. if split == "train":
  59. tf = dataset.training
  60. elif split in ("valid", "val", "validation"):
  61. tf = dataset.validation
  62. elif split in ("test", "testing"):
  63. tf = dataset.testing
  64. else:
  65. msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
  66. raise ValueError(
  67. msg,
  68. )
  69. # `tf.mapped_triples` is a NumPy array of shape (n_triples, 3) with dtype=int64
  70. # Convert each row into a Python tuple of ints
  71. triples_list: List[Tuple[int, int, int]] = [
  72. (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples
  73. ]
  74. # Get the total number of entities and relations in the entire dataset
  75. num_entities = dataset.num_entities
  76. num_relations = dataset.num_relations
  77. # Initialize the base KGDataset with the selected triples
  78. super().__init__(
  79. triples_factory=tf,
  80. triples=triples_list,
  81. num_entities=num_entities,
  82. num_relations=num_relations,
  83. split=split,
  84. )
  85. def __getitem__(self, idx: int) -> torch.Tensor:
  86. data = super().__getitem__(idx)
  87. # Convert the tuple to a torch tensor
  88. return torch.tensor(data, dtype=torch.long)