You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

kg_dataset.py 2.8KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. """PyTorch Dataset wrapping indexed triples (h, r, t)."""
  2. from pathlib import Path
  3. from typing import Dict, List, Tuple
  4. import torch
  5. from pykeen.typing import MappedTriples
  6. from torch.utils.data import Dataset
  7. class KGDataset(Dataset):
  8. """PyTorch Dataset wrapping indexed triples (h, r, t)."""
  9. def __init__(
  10. self,
  11. triples_factory: MappedTriples,
  12. triples: List[Tuple[int, int, int]],
  13. num_entities: int,
  14. num_relations: int,
  15. split: str = "train",
  16. ) -> None:
  17. super().__init__()
  18. self.triples_factory = triples_factory
  19. self.triples = torch.tensor(triples, dtype=torch.long)
  20. self.num_entities = num_entities
  21. self.num_relations = num_relations
  22. self.split = split
  23. def __len__(self) -> int:
  24. """Return the number of triples in the dataset."""
  25. return len(self.triples)
  26. def __getitem__(self, idx: int) -> Tuple[int, int, int]:
  27. """Return the indexed triple at the given index."""
  28. return self.triples[idx]
  29. def entities(self) -> torch.Tensor:
  30. """Return a list of all entities in the dataset."""
  31. return torch.arange(self.num_entities)
  32. def relations(self) -> torch.Tensor:
  33. """Return a list of all relations in the dataset."""
  34. return torch.arange(self.num_relations)
  35. def __repr__(self) -> str:
  36. return (
  37. f"{self.__class__.__name__}(split={self.split!r}, "
  38. f"num_triples={len(self.triples)}, "
  39. f"num_entities={self.num_entities}, "
  40. f"num_relations={self.num_relations})"
  41. )
  42. # ---------------------------- Helper functions --------------------------------
  43. def _map_to_ids(
  44. triples: List[Tuple[str, str, str]],
  45. ) -> Tuple[List[Tuple[int, int, int]], Dict[str, int], Dict[str, int]]:
  46. ent2id: Dict[str, int] = {}
  47. rel2id: Dict[str, int] = {}
  48. mapped: List[Tuple[int, int, int]] = []
  49. def _id(d: Dict[str, int], k: str) -> int:
  50. d.setdefault(k, len(d))
  51. return d[k]
  52. for h, r, t in triples:
  53. mapped.append((_id(ent2id, h), _id(rel2id, r), _id(ent2id, t)))
  54. return mapped, ent2id, rel2id
  55. def load_kg_from_files(
  56. root: str, splits: tuple[str] = ("train", "valid", "test")
  57. ) -> Tuple[Dict[str, KGDataset], Dict[str, int], Dict[str, int]]:
  58. root = Path(root)
  59. split_raw, all_raw = {}, []
  60. for s in splits:
  61. lines = [tuple(line.split("\t")) for line in (root / f"{s}.txt").read_text().splitlines()]
  62. split_raw[s] = lines
  63. all_raw.extend(lines)
  64. mapped, ent2id, rel2id = _map_to_ids(all_raw)
  65. datasets, start = {}, 0
  66. for s in splits:
  67. end = start + len(split_raw[s])
  68. datasets[s] = KGDataset(mapped[start:end], len(ent2id), len(rel2id), s)
  69. start = end
  70. return datasets, ent2id, rel2id