You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

fb15k.py 1.4KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. from typing import List, Tuple
  2. import torch
  3. from pykeen.datasets import FB15k
  4. from .kg_dataset import KGDataset
  5. class FB15KDataset(KGDataset):
  6. """A wrapper around PyKeen's FB15k dataset producing KGDataset instances."""
  7. def __init__(self, split: str = "train") -> None:
  8. dataset = FB15k()
  9. # if split == "train":
  10. if True:
  11. tf = dataset.training
  12. elif split in ("valid", "val", "validation"):
  13. tf = dataset.validation
  14. elif split in ("test", "testing"):
  15. tf = dataset.testing
  16. else:
  17. msg = f"Unrecognized split '{split}'"
  18. raise ValueError(msg)
  19. custom_triples = torch.load(
  20. "/home/n_kazemi/projects/KGEvaluation/CustomDataset/fb15k/fb15k_crec_bi_1.pt",
  21. map_location=torch.device("cpu"),
  22. ).to(torch.long)
  23. # .tolist()
  24. tf = tf.clone_and_exchange_triples(
  25. custom_triples,
  26. )
  27. triples_list: List[Tuple[int, int, int]] = [
  28. (int(h), int(r), int(t)) for h, r, t in tf.mapped_triples
  29. ]
  30. super().__init__(
  31. triples_factory=tf,
  32. triples=triples_list,
  33. num_entities=dataset.num_entities,
  34. num_relations=dataset.num_relations,
  35. split=split,
  36. )
  37. def __getitem__(self, idx: int) -> torch.Tensor:
  38. return torch.tensor(super().__getitem__(idx), dtype=torch.long)