123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- from typing import List, Tuple
-
- import torch
- from pykeen.datasets import WN18, WN18RR
-
- from .kg_dataset import KGDataset # Importing KGDataset from the same module
-
-
class WN18Dataset(KGDataset):
    """A wrapper around PyKeen's WN18 dataset that produces KGDataset instances."""

    # Accepted split names mapped to the PyKeen ``Dataset`` attribute that holds
    # the corresponding TriplesFactory.
    _SPLIT_ATTRS = {
        "train": "training",
        "training": "training",
        "valid": "validation",
        "val": "validation",
        "validation": "validation",
        "test": "testing",
        "testing": "testing",
    }

    def __init__(self, split: str = "train") -> None:
        """
        Initialize the WN18 wrapper.

        :param split: One of "train", "valid", or "test" indicating which split to load.
            The aliases "training", "val", "validation", and "testing" are also accepted.
            The value is forwarded unchanged to ``KGDataset`` as ``split``.
        :raises ValueError: If ``split`` is not a recognized split name.
        """
        # Validate the split *before* constructing WN18() so that a bad name fails
        # immediately instead of after PyKeen has loaded (or downloaded) the data.
        split_attr = self._SPLIT_ATTRS.get(split) if isinstance(split, str) else None
        if split_attr is None:
            msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
            raise ValueError(msg)

        # Load the PyKeen WN18 dataset and pick the requested TriplesFactory.
        dataset = WN18()
        tf = getattr(dataset, split_attr)

        # `tf.mapped_triples` holds integer-ID triples of shape (n_triples, 3); in
        # PyKeen it is a torch LongTensor. One bulk `.tolist()` converts it to
        # Python ints far faster than iterating the tensor row by row, which
        # would allocate a tensor object per element.
        triples_list: List[Tuple[int, int, int]] = [
            (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples.tolist()
        ]

        # Entity/relation counts are taken from the whole dataset, not just this
        # split, so IDs stay consistent across train/valid/test.
        num_entities = dataset.num_entities
        num_relations = dataset.num_relations

        # Initialize the base KGDataset with the selected triples.
        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=num_entities,
            num_relations=num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        """
        Return the triple at ``idx`` as a ``torch.long`` tensor.

        Presumably ``KGDataset.__getitem__`` returns a ``(head, relation, tail)``
        tuple; NOTE(review): confirm against KGDataset.
        """
        data = super().__getitem__(idx)

        # Convert the base-class item to a torch tensor.
        return torch.tensor(data, dtype=torch.long)
-
-
# KGDataset is imported at the top of this module (from .kg_dataset).
-
-
class WN18RRDataset(KGDataset):
    """A wrapper around PyKeen's WN18RR dataset that produces KGDataset instances."""

    # Accepted split names mapped to the PyKeen ``Dataset`` attribute that holds
    # the corresponding TriplesFactory.
    _SPLIT_ATTRS = {
        "train": "training",
        "training": "training",
        "valid": "validation",
        "val": "validation",
        "validation": "validation",
        "test": "testing",
        "testing": "testing",
    }

    def __init__(self, split: str = "train") -> None:
        """
        Initialize the WN18RR wrapper.

        :param split: One of "train", "valid", or "test" indicating which split to load.
            The aliases "training", "val", "validation", and "testing" are also accepted.
            The value is forwarded unchanged to ``KGDataset`` as ``split``.
        :raises ValueError: If ``split`` is not a recognized split name.
        """
        # Validate the split *before* constructing WN18RR() so that a bad name fails
        # immediately instead of after PyKeen has loaded (or downloaded) the data.
        split_attr = self._SPLIT_ATTRS.get(split) if isinstance(split, str) else None
        if split_attr is None:
            msg = f"Unrecognized split '{split}'. Must be one of 'train', 'valid', or 'test'."
            raise ValueError(msg)

        # Load the PyKeen WN18RR dataset and pick the requested TriplesFactory.
        dataset = WN18RR()
        tf = getattr(dataset, split_attr)

        # `tf.mapped_triples` holds integer-ID triples of shape (n_triples, 3); in
        # PyKeen it is a torch LongTensor. One bulk `.tolist()` converts it to
        # Python ints far faster than iterating the tensor row by row, which
        # would allocate a tensor object per element.
        triples_list: List[Tuple[int, int, int]] = [
            (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples.tolist()
        ]

        # Entity/relation counts are taken from the whole dataset, not just this
        # split, so IDs stay consistent across train/valid/test.
        num_entities = dataset.num_entities
        num_relations = dataset.num_relations

        # Initialize the base KGDataset with the selected triples.
        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=num_entities,
            num_relations=num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        """
        Return the triple at ``idx`` as a ``torch.long`` tensor.

        Presumably ``KGDataset.__getitem__`` returns a ``(head, relation, tail)``
        tuple; NOTE(review): confirm against KGDataset.
        """
        data = super().__getitem__(idx)

        # Convert the base-class item to a torch tensor.
        return torch.tensor(data, dtype=torch.long)
|