from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch
from pykeen.datasets import PathDataset

from .kg_dataset import KGDataset

# Default locations; these reproduce the paths this module originally
# hard-coded, so callers that pass no paths see unchanged behavior.
_DEFAULT_DATA_DIR = "/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10"
_DEFAULT_CUSTOM_TRIPLES_PATH = (
    "/home/n_kazemi/projects/KGEvaluation/CustomDataset/yago310/"
    "yago310_crec_bi_1.pt"
)


def _split_attribute(split: str) -> str:
    """Map a user-facing split name to the matching ``PathDataset`` attribute.

    Raises:
        ValueError: If ``split`` is not one of the recognized aliases.
    """
    if split == "train":
        return "training"
    if split in ("valid", "val", "validation"):
        return "validation"
    if split in ("test", "testing"):
        return "testing"
    msg = f"Unrecognized split '{split}'"
    raise ValueError(msg)


class YAGO310Dataset(KGDataset):
    """Wrapper around PyKeen's YAGO3-10 dataset.

    The dataset is read from local ``train.txt`` / ``valid.txt`` / ``test.txt``
    files through :class:`pykeen.datasets.PathDataset`.

    NOTE: the triples of the selected split are *not* used as-is. The split's
    triples factory is cloned with its triples replaced by a tensor of custom
    triples loaded from ``custom_triples_path``; the same file is used
    whichever split is requested.
    """

    def __init__(
        self,
        split: str = "train",
        *,
        data_dir: Optional[Union[str, Path]] = None,
        custom_triples_path: Optional[Union[str, Path]] = None,
    ) -> None:
        """Load the requested split and substitute the custom triples.

        Args:
            split: One of ``"train"``; ``"valid"``, ``"val"``,
                ``"validation"``; ``"test"``, ``"testing"``.
            data_dir: Directory containing ``train.txt``, ``valid.txt`` and
                ``test.txt``. Defaults to the originally hard-coded location.
            custom_triples_path: File readable by ``torch.load`` holding the
                replacement triples. Defaults to the originally hard-coded
                file. Presumably an ``(n, 3)`` tensor of entity/relation
                ids -- TODO confirm against the file's producer.

        Raises:
            ValueError: If ``split`` is not recognized.
        """
        # Validate first so a typo fails before any dataset or file is touched.
        factory_attr = _split_attribute(split)

        root = Path(_DEFAULT_DATA_DIR if data_dir is None else data_dir)
        if custom_triples_path is None:
            custom_triples_path = _DEFAULT_CUSTOM_TRIPLES_PATH

        dataset = PathDataset(
            training_path=str(root / "train.txt"),
            testing_path=str(root / "test.txt"),
            validation_path=str(root / "valid.txt"),
        )
        tf = getattr(dataset, factory_attr)

        # Load on CPU regardless of the device the tensor was saved from, and
        # cast to long before handing the tensor to the triples factory.
        custom_triples = torch.load(
            str(custom_triples_path),
            map_location=torch.device("cpu"),
        ).to(torch.long)

        # Keep the split's factory, but swap in the custom triples.
        tf = tf.clone_and_exchange_triples(custom_triples)

        # One bulk tolist() instead of converting three 0-d tensors per row.
        triples_list: List[Tuple[int, int, int]] = [
            (h, r, t) for h, r, t in tf.mapped_triples.tolist()
        ]

        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=dataset.num_entities,
            num_relations=dataset.num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the base class's item at ``idx`` as a ``torch.long`` tensor.

        The base-class item is presumably an ``(h, r, t)`` triple -- verify
        against ``KGDataset.__getitem__``.
        """
        return torch.tensor(super().__getitem__(idx), dtype=torch.long)