12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- from typing import List, Tuple
-
- import torch
- from pykeen.datasets import PathDataset
-
- from .kg_dataset import KGDataset
-
-
class YAGO310Dataset(KGDataset):
    """Wrapper around PyKeen's YAGO3-10 dataset.

    The dataset is read from local ``train.txt`` / ``test.txt`` / ``valid.txt``
    files through :class:`pykeen.datasets.PathDataset`. The triples of the
    selected split's factory are then replaced by a custom triple tensor
    loaded from disk.

    NOTE(review): the custom triples are swapped in regardless of ``split``,
    so every split exposes the same triples; only the triples factory they
    are attached to differs. Confirm this is intended.
    """

    # Defaults reproduce the locations that used to be hard-coded in __init__.
    _DEFAULT_DATA_DIR = "/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10"
    _DEFAULT_CUSTOM_TRIPLES_PATH = (
        "/home/n_kazemi/projects/KGEvaluation/CustomDataset/yago310/yago310_crec_bi_1.pt"
    )

    @staticmethod
    def _split_attribute(split: str) -> str:
        """Return the ``PathDataset`` attribute name holding the given split.

        Raises:
            ValueError: If ``split`` is not one of the accepted aliases.
        """
        if split == "train":
            return "training"
        if split in ("valid", "val", "validation"):
            return "validation"
        if split in ("test", "testing"):
            return "testing"
        msg = f"Unrecognized split '{split}'"
        raise ValueError(msg)

    def __init__(
        self,
        split: str = "train",
        *,
        data_dir: str = _DEFAULT_DATA_DIR,
        custom_triples_path: str = _DEFAULT_CUSTOM_TRIPLES_PATH,
    ) -> None:
        """Build the dataset for one split.

        Args:
            split: ``"train"``, one of ``"valid"``/``"val"``/``"validation"``,
                or one of ``"test"``/``"testing"``.
            data_dir: Directory containing ``train.txt``, ``test.txt`` and
                ``valid.txt``.
            custom_triples_path: File loaded with ``torch.load`` whose content
                replaces the triples of the selected split's factory.

        Raises:
            ValueError: If ``split`` is not recognized.
        """
        # Validate the split first so a typo fails before any file is touched.
        factory_name = self._split_attribute(split)

        dataset = PathDataset(
            training_path=f"{data_dir}/train.txt",
            testing_path=f"{data_dir}/test.txt",
            validation_path=f"{data_dir}/valid.txt",
        )
        tf = getattr(dataset, factory_name)

        # NOTE(review): torch.load unpickles the file, so only point this at
        # trusted files (consider weights_only=True if the installed torch
        # version supports it).
        custom_triples = torch.load(
            custom_triples_path,
            map_location=torch.device("cpu"),
        ).to(torch.long)

        # Keep the split's entity/relation mappings but swap in the custom triples.
        tf = tf.clone_and_exchange_triples(custom_triples)

        # One bulk tolist() call instead of converting each tensor element
        # to int individually; the 3-way unpack still enforces (h, r, t) rows.
        triples_list: List[Tuple[int, int, int]] = [
            (h, r, t) for h, r, t in tf.mapped_triples.tolist()
        ]

        super().__init__(
            triples_factory=tf,
            triples=triples_list,
            num_entities=dataset.num_entities,
            num_relations=dataset.num_relations,
            split=split,
        )

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the item at ``idx`` from the parent class as a ``torch.long`` tensor."""
        return torch.tensor(super().__getitem__(idx), dtype=torch.long)
|