from typing import List, Tuple

import torch
from pykeen.datasets import OpenBioLink

from .kg_dataset import KGDataset  # Importing KGDataset from the same module

# Maps every accepted split name to the PyKeen Dataset attribute holding that split's
# TriplesFactory. Checked before loading so a typo fails fast instead of after the
# (potentially large) dataset has been loaded.
_SPLIT_TO_ATTRIBUTE = {
    "train": "training",
    "training": "training",
    "valid": "validation",
    "val": "validation",
    "validation": "validation",
    "test": "testing",
    "testing": "testing",
}


class OpenBioLinkDataset(KGDataset):
    """A wrapper around PyKeen's OpenBioLink dataset that produces KGDataset instances."""

    def __init__(self, split: str = "train") -> None:
        """
        Initialize the OpenBioLink wrapper.

        :param split: Which split to load. Accepted values are "train"/"training",
            "valid"/"val"/"validation", and "test"/"testing".
        :raises ValueError: If ``split`` is not one of the accepted values.
        """
        # Validate the split name before doing any expensive dataset loading.
        attribute = _SPLIT_TO_ATTRIBUTE.get(split)
        if attribute is None:
            msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
            raise ValueError(msg)

        # Load the PyKeen OpenBioLink dataset and pick the requested split's TriplesFactory.
        dataset = OpenBioLink()
        tf = getattr(dataset, attribute)

        # `tf.mapped_triples` is PyKeen's (n_triples, 3) matrix of integer IDs, one
        # (head, relation, tail) row per triple. `.tolist()` converts the whole matrix to
        # Python ints in a single native call, which is much faster than calling int()
        # on every element individually.
        triples_list: List[Tuple[int, int, int]] = [
            (h, r, t) for h, r, t in tf.mapped_triples.tolist()
        ]

        # Entity/relation counts come from the full dataset, not just this split.
        num_entities = dataset.num_entities
        num_relations = dataset.num_relations

        # Initialize the base KGDataset with the selected triples; `split` is passed
        # through exactly as the caller supplied it.
        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=num_entities,
            num_relations=num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        """
        Return the item at ``idx`` as a ``torch.long`` tensor.

        :param idx: Index forwarded unchanged to ``KGDataset.__getitem__``.
        :return: A tensor built from whatever the base class returns for ``idx``
            (presumably a ``(head, relation, tail)`` tuple of ints — confirm in KGDataset).
        """
        data = super().__getitem__(idx)
        # Convert the base-class item to a torch tensor
        return torch.tensor(data, dtype=torch.long)