123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
from collections.abc import Iterator

import numpy as np
import torch
from torch.utils.data.sampler import Sampler
-
-
def negative_sampling(
    pos: torch.Tensor,
    num_entities: int,
    num_negatives: int = 1,
    mode: "tuple[str, ...] | list[str]" = ("head", "tail"),
    *,
    generator: "torch.Generator | None" = None,
) -> torch.Tensor:
    """Create corrupted copies of ``pos`` by overwriting entity columns with random IDs.

    Each row of ``pos`` is repeated ``num_negatives`` times consecutively
    (``repeat_interleave`` along dim 0). In every copy, column 0 and/or
    column 2 is overwritten with an integer drawn uniformly from
    ``[0, num_entities)`` on the same device as ``pos``.

    Args:
        pos: 2-D integer tensor with at least 3 columns. Only columns 0 and 2
            are replaced. Presumably the layout is (head, relation, tail);
            TODO confirm against callers.
        num_entities: Exclusive upper bound for the sampled replacement IDs.
        num_negatives: Number of corrupted copies produced per input row.
        mode: Which columns to corrupt. If it contains ``"head"``, column 0 is
            replaced. If it contains ``"tail"``, column 2 is replaced. When
            both are present, both columns of every copy are replaced.
        generator: Optional ``torch.Generator`` for reproducible sampling.
            ``None`` uses PyTorch's global RNG.

    Returns:
        A tensor with ``pos.shape[0] * num_negatives`` rows. The replacements
        are written into the ``repeat_interleave`` result.
    """
    # NOTE(review): with both modes each negative has head AND tail replaced.
    # Many KG embedding setups corrupt only one side per negative; confirm
    # that corrupting both is intended.
    negatives = pos.repeat_interleave(num_negatives, dim=0)
    num_rows = negatives.shape[0]

    # Draw order (head first, then tail) is kept fixed so a given RNG state
    # always produces the same negatives.
    if "head" in mode:
        negatives[:, 0] = torch.randint(
            0, num_entities, (num_rows,), device=negatives.device, generator=generator
        )
    if "tail" in mode:
        negatives[:, 2] = torch.randint(
            0, num_entities, (num_rows,), device=negatives.device, generator=generator
        )
    return negatives
-
-
class DeterministicRandomSampler(Sampler):
    """Sampler yielding a fixed, seed-determined random subset of dataset indices.

    The subset and its order are computed once in ``__init__`` with
    ``numpy.random.default_rng(seed)``. Every iteration over the sampler, and
    every sampler built with the same arguments, yields the same sequence.
    """

    def __init__(self, dataset_size: int, sample_size: "int | None" = 0, seed: int = 42) -> None:
        """
        Args:
            dataset_size: Number of items in the dataset. Indices are drawn
                from ``range(dataset_size)``.
            sample_size: Number of indices to draw. ``0`` (the default) or
                ``None`` means use the full dataset.
            seed: Seed for the NumPy generator that fixes the subset and order.

        Raises:
            ValueError: If ``sample_size`` is negative or exceeds ``dataset_size``.
        """
        self.dataset_size = dataset_size
        self.seed = seed
        # None is accepted as an alias for 0 ("whole dataset"). Without this,
        # the size comparison below would raise TypeError on None.
        if sample_size is None or sample_size == 0:
            self.sample_size = self.dataset_size
        else:
            self.sample_size = sample_size

        # A negative size would slice from the end of the permutation and make
        # __len__ return a negative value, which len() rejects.
        if self.sample_size < 0:
            msg = f"Sample size {self.sample_size} must be non-negative."
            raise ValueError(msg)
        if self.sample_size > self.dataset_size:
            msg = f"Sample size {self.sample_size} exceeds dataset size {self.dataset_size}."
            raise ValueError(msg)

        self.indices = self._generate_deterministic_indices()

    def _generate_deterministic_indices(self) -> list[int]:
        """Return the first ``sample_size`` entries of a seeded permutation of the dataset indices."""
        rng = np.random.default_rng(self.seed)
        shuffled = rng.permutation(self.dataset_size)
        return shuffled[: self.sample_size].tolist()

    def __iter__(self) -> Iterator[int]:
        """Iterate over the precomputed indices, in the same order every time."""
        return iter(self.indices)

    def __len__(self) -> int:
        """Return the number of indices this sampler yields."""
        return self.sample_size
|