You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

yago3_10.py 2.3KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  import os
  from typing import List, Optional, Tuple

  import torch
  from pykeen.datasets import PathDataset

  from .kg_dataset import KGDataset
  class YAGO310Dataset(KGDataset):
      """Wrapper around PyKEEN's YAGO3-10 dataset with custom triples swapped in.

      The three YAGO3-10 split files are read from ``data_dir`` through
      :class:`pykeen.datasets.PathDataset`. The triples factory of the
      requested split is then cloned with its triples replaced by the tensor
      stored at ``custom_triples_path``, unless that argument is ``None``.

      NOTE(review): by default the *same* custom-triples file replaces the
      triples of every split, so "train", "valid" and "test" all expose the
      identical custom triples -- confirm this is intended.
      """

      # Directory holding train.txt, valid.txt and test.txt.
      DEFAULT_DATA_DIR = "/home/n_kazemi/.data/pykeen/datasets/yago310/YAGO3-10"
      # Tensor file (torch.save) whose triples replace those of the chosen split.
      DEFAULT_CUSTOM_TRIPLES_PATH = (
          "/home/n_kazemi/projects/KGEvaluation/CustomDataset/yago310/yago310_crec_bi_1.pt"
      )

      # Accepted split names -> PathDataset attribute holding that split.
      _SPLIT_ATTRS = {
          "train": "training",
          "training": "training",
          "valid": "validation",
          "val": "validation",
          "validation": "validation",
          "test": "testing",
          "testing": "testing",
      }

      def __init__(
          self,
          split: str = "train",
          *,
          data_dir: str = DEFAULT_DATA_DIR,
          custom_triples_path: Optional[str] = DEFAULT_CUSTOM_TRIPLES_PATH,
      ) -> None:
          """Load YAGO3-10 and prepare the triples of ``split``.

          Args:
              split: Which split's triples factory to use. Accepts "train" /
                  "training", "valid" / "val" / "validation" and
                  "test" / "testing".
              data_dir: Directory containing train.txt, valid.txt and test.txt.
              custom_triples_path: File loaded with ``torch.load`` whose triples
                  replace the split's own triples. ``None`` keeps the split's
                  original triples.

          Raises:
              ValueError: If ``split`` is not recognised, or the custom triples
                  are not a 2-D tensor with exactly three columns.
          """
          # Validate the cheap argument before the expensive dataset load.
          split_attr = self._SPLIT_ATTRS.get(split)
          if split_attr is None:
              msg = f"Unrecognized split '{split}'"
              raise ValueError(msg)

          dataset = PathDataset(
              training_path=os.path.join(data_dir, "train.txt"),
              testing_path=os.path.join(data_dir, "test.txt"),
              validation_path=os.path.join(data_dir, "valid.txt"),
          )
          # getattr touches only the requested split's factory, like the
          # original if/elif chain did.
          tf = getattr(dataset, split_attr)

          if custom_triples_path is not None:
              # SECURITY: torch.load unpickles the file -- only load trusted
              # paths (consider weights_only=True where the torch version
              # supports it).
              custom_triples = torch.load(
                  custom_triples_path,
                  map_location=torch.device("cpu"),
              ).to(torch.long)
              # Each row is later unpacked as (head, relation, tail).
              if custom_triples.ndim != 2 or custom_triples.shape[1] != 3:
                  msg = (
                      "Custom triples must have shape (n, 3), "
                      f"got {tuple(custom_triples.shape)}"
                  )
                  raise ValueError(msg)
              # Assumes the stored IDs already use this dataset's entity and
              # relation ID mapping -- TODO confirm against the file's producer.
              tf = tf.clone_and_exchange_triples(custom_triples)

          # One tolist() call converts the whole tensor in C instead of
          # iterating rows and calling int() on every 0-d tensor element.
          triples_list: List[Tuple[int, int, int]] = [
              (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples.tolist()
          ]
          super().__init__(
              triples_factory=tf,
              triples=triples_list,
              num_entities=dataset.num_entities,
              num_relations=dataset.num_relations,
              split=split,
          )

      def __getitem__(self, idx: int) -> torch.Tensor:
          """Return item ``idx`` of the base dataset as a ``torch.long`` tensor.

          Presumably the base item is an (h, r, t) tuple -- confirm in KGDataset.
          """
          return torch.tensor(super().__getitem__(idx), dtype=torch.long)