123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
- from typing import List, Tuple
-
- import torch
- from pykeen.datasets import OpenBioLink
-
- from .kg_dataset import KGDataset # Importing KGDataset from the same module
-
-
class OpenBioLinkDataset(KGDataset):
    """A wrapper around PyKeen's OpenBioLink dataset that produces KGDataset instances."""

    # Maps every accepted split spelling to the name of the attribute on the
    # PyKeen dataset object that holds that split's TriplesFactory.
    _SPLIT_ATTRIBUTES = {
        "train": "training",
        "training": "training",
        "valid": "validation",
        "val": "validation",
        "validation": "validation",
        "test": "testing",
        "testing": "testing",
    }

    def __init__(self, split: str = "train") -> None:
        """
        Initialize the OpenBioLink wrapper for a single split.

        :param split: Which split to load: "train"/"training", "valid"/"val"/"validation",
            or "test"/"testing".
        :raises ValueError: If ``split`` is not one of the recognized names.
        """
        # Validate the split name *before* loading, so a typo fails fast instead of
        # only after the full OpenBioLink dataset has been constructed.
        attribute = self._SPLIT_ATTRIBUTES.get(split)
        if attribute is None:
            msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
            raise ValueError(msg)

        # Load the PyKeen OpenBioLink dataset and pick the requested TriplesFactory
        dataset = OpenBioLink()
        tf = getattr(dataset, attribute)

        # `tf.mapped_triples` holds one (head, relation, tail) ID row per triple.
        # A single `.tolist()` call (available on both torch tensors and NumPy arrays)
        # yields plain Python ints in native code, which is far faster on large
        # graphs than unpacking each row and calling `int()` per element.
        triples_list: List[Tuple[int, int, int]] = [
            (h, r, t) for h, r, t in tf.mapped_triples.tolist()
        ]

        # Entity/relation counts come from the whole dataset, not just this split,
        # so IDs are consistent across train/valid/test wrappers.
        num_entities = dataset.num_entities
        num_relations = dataset.num_relations

        # Initialize the base KGDataset with the selected triples
        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=num_entities,
            num_relations=num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        """
        Return the item at ``idx`` from the base dataset as a ``torch.long`` tensor.

        :param idx: Index of the triple within this split.
        :returns: The base class's item converted with ``torch.tensor(..., dtype=torch.long)``.
        """
        data = super().__getitem__(idx)

        # The base class presumably returns a (head, relation, tail) tuple — TODO confirm
        return torch.tensor(data, dtype=torch.long)
|