from typing import List, Optional, Tuple

import torch
from pykeen.datasets import FB15k

from .kg_dataset import KGDataset

# Default location of the pre-computed replacement triples.
# NOTE(review): this is a machine-specific absolute path; pass
# ``custom_triples_path`` to ``FB15KDataset`` to point somewhere else.
_DEFAULT_CUSTOM_TRIPLES_PATH = (
    "/home/n_kazemi/projects/KGEvaluation/CustomDataset/fb15k/fb15k_crec_bi_1.pt"
)


class FB15KDataset(KGDataset):
    """A wrapper around PyKeen's FB15k dataset producing KGDataset instances."""

    def __init__(
        self,
        split: str = "train",
        *,
        custom_triples_path: Optional[str] = _DEFAULT_CUSTOM_TRIPLES_PATH,
    ) -> None:
        """Build the dataset for one split of FB15k.

        Args:
            split: Which split to wrap. Accepted values are ``"train"``;
                ``"valid"``, ``"val"`` or ``"validation"``; and ``"test"`` or
                ``"testing"``.
            custom_triples_path: Path to a ``torch.save``-d object whose
                contents replace the triples of the selected split. It is
                loaded on CPU and cast to ``torch.long``. Pass ``None`` to
                keep the triples shipped with FB15k.

        Raises:
            ValueError: If ``split`` is not one of the accepted names.
        """
        dataset = FB15k()

        # Pick the triples factory that belongs to the requested split.
        if split == "train":
            tf = dataset.training
        elif split in ("valid", "val", "validation"):
            tf = dataset.validation
        elif split in ("test", "testing"):
            tf = dataset.testing
        else:
            msg = f"Unrecognized split '{split}'"
            raise ValueError(msg)

        if custom_triples_path is not None:
            # SECURITY: torch.load unpickles the file, so only point this at
            # files from a trusted source.
            # NOTE(review): assumes the file holds a single tensor of ID
            # triples -- confirm against whatever produced the .pt file.
            custom_triples = torch.load(
                custom_triples_path,
                map_location=torch.device("cpu"),
            ).to(torch.long)
            # Swap in the custom triples on a clone of the selected factory.
            tf = tf.clone_and_exchange_triples(custom_triples)

        # A single tolist() call converts the whole tensor to Python ints,
        # avoiding one int() call per 0-d tensor element.
        triples_list: List[Tuple[int, int, int]] = [
            (h, r, t) for h, r, t in tf.mapped_triples.tolist()
        ]

        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=dataset.num_entities,
            num_relations=dataset.num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the parent dataset's item at ``idx`` as a ``torch.long`` tensor."""
        return torch.tensor(super().__getitem__(idx), dtype=torch.long)