- """PyTorch Dataset wrapping indexed triples (h, r, t)."""
-
- from pathlib import Path
- from typing import Dict, List, Tuple
-
- import torch
- from pykeen.typing import MappedTriples
- from torch.utils.data import Dataset
-
-
- class KGDataset(Dataset):
- """PyTorch Dataset wrapping indexed triples (h, r, t)."""
-
    def __init__(
        self,
        triples: List[Tuple[int, int, int]],
        num_entities: int,
        num_relations: int,
        split: str = "train",
    ) -> None:
        super().__init__()
        self.triples = torch.tensor(triples, dtype=torch.long)
        self.num_entities = num_entities
        self.num_relations = num_relations
        self.split = split

    def __len__(self) -> int:
        """Return the number of triples in the dataset."""
        return len(self.triples)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the (h, r, t) index tensor at the given position."""
        return self.triples[idx]

    def entities(self) -> torch.Tensor:
        """Return a tensor of all entity IDs in the dataset."""
        return torch.arange(self.num_entities)

    def relations(self) -> torch.Tensor:
        """Return a tensor of all relation IDs in the dataset."""
        return torch.arange(self.num_relations)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(split={self.split!r}, "
            f"num_triples={len(self.triples)}, "
            f"num_entities={self.num_entities}, "
            f"num_relations={self.num_relations})"
        )
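
# Note: because ``__getitem__`` returns a (3,) LongTensor, a default
# ``torch.utils.data.DataLoader`` collates a batch into a LongTensor of shape
# (batch_size, 3) whose columns are head, relation and tail indices.
# Illustrative values only:
#
#     loader = DataLoader(KGDataset([(0, 0, 1), (1, 0, 2)], 3, 1), batch_size=2)
#     next(iter(loader))  # tensor([[0, 0, 1], [1, 0, 2]])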


# ---------------------------- Helper functions --------------------------------


def _map_to_ids(
    triples: List[Tuple[str, str, str]],
) -> Tuple[List[Tuple[int, int, int]], Dict[str, int], Dict[str, int]]:
    """Map labeled (h, r, t) triples to integer IDs assigned in first-seen order."""
    ent2id: Dict[str, int] = {}
    rel2id: Dict[str, int] = {}
    mapped: List[Tuple[int, int, int]] = []

    def _id(d: Dict[str, int], k: str) -> int:
        # Register the key with the next free ID on first sight, then look it up.
        d.setdefault(k, len(d))
        return d[k]

    for h, r, t in triples:
        mapped.append((_id(ent2id, h), _id(rel2id, r), _id(ent2id, t)))

    return mapped, ent2id, rel2id
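
# Illustrative behaviour of ``_map_to_ids`` (the labels below are made-up
# examples, not data from this repository):
#
#     _map_to_ids([("a", "likes", "b"), ("b", "likes", "c")])
#     -> ([(0, 0, 1), (1, 0, 2)], {"a": 0, "b": 1, "c": 2}, {"likes": 0})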


def load_kg_from_files(
    root: str, splits: Tuple[str, ...] = ("train", "valid", "test")
) -> Tuple[Dict[str, KGDataset], Dict[str, int], Dict[str, int]]:
    """Load tab-separated ``<split>.txt`` files from ``root`` into KGDatasets."""
    root_dir = Path(root)
    split_raw, all_raw = {}, []
    for s in splits:
        lines = [
            tuple(line.split("\t"))
            for line in (root_dir / f"{s}.txt").read_text().splitlines()
            if line.strip()  # skip blank lines
        ]
        split_raw[s] = lines
        all_raw.extend(lines)
    mapped, ent2id, rel2id = _map_to_ids(all_raw)
    # Entity/relation IDs are shared across splits; slice the mapped triples
    # back into per-split datasets in the order they were read.
    datasets, start = {}, 0
    for s in splits:
        end = start + len(split_raw[s])
        datasets[s] = KGDataset(mapped[start:end], len(ent2id), len(rel2id), s)
        start = end
    return datasets, ent2id, rel2id
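

# ------------------------------ Usage sketch -----------------------------------
# Minimal, self-contained smoke test (illustrative only): it writes a tiny
# synthetic dataset to a temporary directory, so no real benchmark files are
# assumed, and the entity/relation names below are made up.
if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        toy = {
            "train": [("alice", "knows", "bob"), ("bob", "knows", "carol")],
            "valid": [("alice", "knows", "carol")],
            "test": [("carol", "knows", "alice")],
        }
        for split, rows in toy.items():
            (Path(tmp) / f"{split}.txt").write_text(
                "\n".join("\t".join(row) for row in rows) + "\n"
            )

        datasets, ent2id, rel2id = load_kg_from_files(tmp)
        print(datasets["train"])     # 2 triples, 3 entities, 1 relation
        print(datasets["train"][0])  # tensor([0, 0, 1]) == (alice, knows, bob)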