"""PyTorch Dataset wrapping indexed triples (h, r, t).""" from pathlib import Path from typing import Dict, List, Tuple import torch from pykeen.typing import MappedTriples from torch.utils.data import Dataset class KGDataset(Dataset): """PyTorch Dataset wrapping indexed triples (h, r, t).""" def __init__( self, triples_factory: MappedTriples, triples: List[Tuple[int, int, int]], num_entities: int, num_relations: int, split: str = "train", ) -> None: super().__init__() self.triples_factory = triples_factory self.triples = torch.tensor(triples, dtype=torch.long) self.num_entities = num_entities self.num_relations = num_relations self.split = split def __len__(self) -> int: """Return the number of triples in the dataset.""" return len(self.triples) def __getitem__(self, idx: int) -> Tuple[int, int, int]: """Return the indexed triple at the given index.""" return self.triples[idx] def entities(self) -> torch.Tensor: """Return a list of all entities in the dataset.""" return torch.arange(self.num_entities) def relations(self) -> torch.Tensor: """Return a list of all relations in the dataset.""" return torch.arange(self.num_relations) def __repr__(self) -> str: return ( f"{self.__class__.__name__}(split={self.split!r}, " f"num_triples={len(self.triples)}, " f"num_entities={self.num_entities}, " f"num_relations={self.num_relations})" ) # ---------------------------- Helper functions -------------------------------- def _map_to_ids( triples: List[Tuple[str, str, str]], ) -> Tuple[List[Tuple[int, int, int]], Dict[str, int], Dict[str, int]]: ent2id: Dict[str, int] = {} rel2id: Dict[str, int] = {} mapped: List[Tuple[int, int, int]] = [] def _id(d: Dict[str, int], k: str) -> int: d.setdefault(k, len(d)) return d[k] for h, r, t in triples: mapped.append((_id(ent2id, h), _id(rel2id, r), _id(ent2id, t))) return mapped, ent2id, rel2id def load_kg_from_files( root: str, splits: tuple[str] = ("train", "valid", "test") ) -> Tuple[Dict[str, KGDataset], Dict[str, int], Dict[str, int]]: root = Path(root) split_raw, all_raw = {}, [] for s in splits: lines = [tuple(line.split("\t")) for line in (root / f"{s}.txt").read_text().splitlines()] split_raw[s] = lines all_raw.extend(lines) mapped, ent2id, rel2id = _map_to_ids(all_raw) datasets, start = {}, 0 for s in splits: end = start + len(split_raw[s]) datasets[s] = KGDataset(mapped[start:end], len(ent2id), len(rel2id), s) start = end return datasets, ent2id, rel2id