You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

sampling.py 2.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
from collections.abc import Iterator, Sequence
from typing import Optional

import numpy as np
import torch
from torch.utils.data.sampler import Sampler
  4. def negative_sampling(
  5. pos: torch.Tensor,
  6. num_entities: int,
  7. num_negatives: int = 1,
  8. mode: list = ["head", "tail"],
  9. ) -> torch.Tensor:
  10. pos = pos.repeat_interleave(num_negatives, dim=0)
  11. if "head" in mode:
  12. neg = torch.randint(0, num_entities, (pos.shape[0],), device=pos.device)
  13. pos[:, 0] = neg
  14. if "tail" in mode:
  15. neg = torch.randint(0, num_entities, (pos.shape[0],), device=pos.device)
  16. pos[:, 2] = neg
  17. return pos
  18. class DeterministicRandomSampler(Sampler):
  19. """
  20. A deterministic random sampler that selects a fixed number of samples
  21. in a reproducible manner using a given seed.
  22. """
  23. def __init__(self, dataset_size: int, sample_size: int = 0, seed: int = 42) -> None:
  24. """
  25. Args:
  26. dataset (Dataset): PyTorch dataset.
  27. sample_size (int, optional): Number of samples to draw. If None, use full dataset.
  28. seed (int): Seed for reproducibility.
  29. """
  30. self.dataset_size = dataset_size
  31. self.seed = seed
  32. self.sample_size = sample_size if sample_size != 0 else self.dataset_size
  33. # Ensure sample size is within dataset size
  34. if self.sample_size > self.dataset_size:
  35. msg = f"Sample size {self.sample_size} exceeds dataset size {self.dataset_size}."
  36. raise ValueError(
  37. msg,
  38. )
  39. self.indices = self._generate_deterministic_indices()
  40. def _generate_deterministic_indices(self) -> list[int]:
  41. """
  42. Generates a fixed random subset of indices using the given seed.
  43. """
  44. rng = np.random.default_rng(self.seed) # NumPy's Generator for better reproducibility
  45. all_indices = rng.permutation(self.dataset_size) # Shuffle full dataset indices
  46. return all_indices[: self.sample_size].tolist() # Select only the desired number of samples
  47. def __iter__(self) -> iter:
  48. """
  49. Yields the shuffled dataset indices.
  50. """
  51. return iter(self.indices)
  52. def __len__(self) -> int:
  53. """
  54. Returns the total number of samples drawn.
  55. """
  56. return self.sample_size