You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

open_bio_link.py 1.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. from typing import List, Tuple
  2. import torch
  3. from pykeen.datasets import OpenBioLink
  4. from .kg_dataset import KGDataset # Importing KGDataset from the same module
  5. class OpenBioLinkDataset(KGDataset):
  6. """A wrapper around PyKeen's OpenBioLink dataset that produces KGDataset instances."""
  7. def __init__(self, split: str = "train") -> None:
  8. """
  9. Initialize the WN18RR wrapper.
  10. :param split: One of "train", "valid", or "test" indicating which split to load.
  11. """
  12. # Load the PyKeen WN18 dataset
  13. dataset = OpenBioLink()
  14. # Select the appropriate TriplesFactory based on the requested split
  15. if split == "train":
  16. tf = dataset.training
  17. elif split in ("valid", "val", "validation"):
  18. tf = dataset.validation
  19. elif split in ("test", "testing"):
  20. tf = dataset.testing
  21. else:
  22. msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
  23. raise ValueError(
  24. msg,
  25. )
  26. # `tf.mapped_triples` is a NumPy array of shape (n_triples, 3) with dtype=int64
  27. # Convert each row into a Python tuple of ints
  28. triples_list: List[Tuple[int, int, int]] = [
  29. (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples
  30. ]
  31. # Get the total number of entities and relations in the entire dataset
  32. num_entities = dataset.num_entities
  33. num_relations = dataset.num_relations
  34. # Initialize the base KGDataset with the selected triples
  35. super().__init__(
  36. triples_factory=tf,
  37. triples=triples_list,
  38. num_entities=num_entities,
  39. num_relations=num_relations,
  40. split=split,
  41. )
  42. def __getitem__(self, idx: int) -> torch.Tensor:
  43. data = super().__getitem__(idx)
  44. # Convert the tuple to a torch tensor
  45. return torch.tensor(data, dtype=torch.long)