12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849 |
from typing import List, Optional, Tuple

import torch
from pykeen.datasets import FB15k

from .kg_dataset import KGDataset
-
-
class FB15KDataset(KGDataset):
    """FB15k knowledge-graph dataset loaded through PyKeen, exposed as a ``KGDataset``.

    By default the triples of every split are replaced by a custom tensor of
    mapped (head, relation, tail) ids loaded from
    ``DEFAULT_CUSTOM_TRIPLES_PATH``; this preserves the previous hard-coded
    behaviour. Pass ``custom_triples_path=None`` to use PyKeen's own
    training/validation/testing triples for the requested split instead.
    """

    # Previously hard-coded location of the custom mapped-triples tensor.
    # NOTE(review): machine-specific absolute path; consider passing
    # ``custom_triples_path`` explicitly or sourcing it from configuration.
    DEFAULT_CUSTOM_TRIPLES_PATH = (
        "/home/n_kazemi/projects/KGEvaluation/CustomDataset/fb15k/fb15k_crec_bi_1.pt"
    )

    # Accepted split names -> attribute of the PyKeen dataset holding that
    # split's triples factory.
    _SPLIT_ATTRS = {
        "train": "training",
        "valid": "validation",
        "val": "validation",
        "validation": "validation",
        "test": "testing",
        "testing": "testing",
    }

    def __init__(
        self,
        split: str = "train",
        custom_triples_path: Optional[str] = DEFAULT_CUSTOM_TRIPLES_PATH,
    ) -> None:
        """Load FB15k and build the dataset for ``split``.

        Args:
            split: ``"train"``, ``"valid"``/``"val"``/``"validation"``, or
                ``"test"``/``"testing"``.
            custom_triples_path: Path to a ``torch.save``-d tensor of mapped
                triples that replaces the split's triples. ``None`` keeps
                PyKeen's triples for ``split``.

        Raises:
            ValueError: If ``split`` is not recognized, or the custom file does
                not contain a 2-D tensor with three columns.
        """
        # Validate before the (potentially expensive) dataset load.
        try:
            factory_attr = self._SPLIT_ATTRS[split]
        except KeyError:
            msg = f"Unrecognized split '{split}'"
            raise ValueError(msg) from None

        dataset = FB15k()

        if custom_triples_path is None:
            tf = getattr(dataset, factory_attr)
        else:
            custom_triples = self._load_custom_triples(custom_triples_path)
            # The training factory is the template whose entity/relation id
            # mappings are kept (same as the previous behaviour for all splits);
            # the custom tensor is assumed to use those ids.
            tf = dataset.training.clone_and_exchange_triples(custom_triples)

        # One bulk tensor->list conversion instead of per-element int() calls.
        triples_list: List[Tuple[int, int, int]] = [
            (h, r, t) for h, r, t in tf.mapped_triples.tolist()
        ]

        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=dataset.num_entities,
            num_relations=dataset.num_relations,
            split=split,
        )

    @staticmethod
    def _load_custom_triples(path: str) -> torch.Tensor:
        """Load a mapped-triples tensor from ``path`` onto the CPU as int64.

        Raises:
            ValueError: If the file does not hold a 2-D tensor with 3 columns.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted files.
        triples = torch.load(path, map_location=torch.device("cpu"))
        if (
            not isinstance(triples, torch.Tensor)
            or triples.dim() != 2
            or triples.shape[1] != 3
        ):
            shape = tuple(triples.shape) if isinstance(triples, torch.Tensor) else None
            msg = (
                f"Expected a (num_triples, 3) tensor in '{path}', "
                f"got {type(triples).__name__} with shape {shape}"
            )
            raise ValueError(msg)
        return triples.to(torch.long)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return item ``idx`` from the parent dataset as an int64 tensor.

        Presumably the parent returns the stored (head, relation, tail)
        tuple — TODO confirm against ``KGDataset.__getitem__``.
        """
        return torch.tensor(super().__getitem__(idx), dtype=torch.long)
|