Meta Byte Track

dist.py 6.9KB

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# This file mainly comes from
# https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/comm.py
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
"""
This file contains primitives for multi-GPU communication.
This is useful when doing distributed training.
"""
import functools
import logging
import pickle
import time

import numpy as np
import torch
from torch import distributed as dist

__all__ = [
    "is_main_process",
    "synchronize",
    "get_world_size",
    "get_rank",
    "get_local_rank",
    "get_local_size",
    "time_synchronized",
    "gather",
    "all_gather",
]

_LOCAL_PROCESS_GROUP = None

def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when using distributed training.
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()

def get_world_size() -> int:
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank() -> int:
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()

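# Example (a minimal usage sketch, not part of the upstream file): create an
# output directory on rank 0 only, then block until it exists before any rank
# tries to write into it. The function name and argument are illustrative.
def _example_make_output_dir(output_dir):
    import os

    if get_rank() == 0:
        os.makedirs(output_dir, exist_ok=True)
    synchronize()  # other ranks wait here until rank 0 has created the directory
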
def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process group.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)


def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group, i.e. the number of processes per machine.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)


def is_main_process() -> bool:
    return get_rank() == 0

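# Example (a minimal usage sketch, not part of the upstream file): restrict
# side effects such as checkpointing to the main process, then synchronize so
# no rank races ahead. The function name and arguments are illustrative.
def _example_save_checkpoint(state_dict, path):
    if is_main_process():
        torch.save(state_dict, path)
    synchronize()
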
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on the gloo backend, containing all the ranks.
    The result is cached.
    """
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo")
    else:
        return dist.group.WORLD

def _serialize_to_tensor(data, group):
    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    device = torch.device("cpu" if backend == "gloo" else "cuda")

    buffer = pickle.dumps(data)
    if len(buffer) > 1024 ** 3:
        logger = logging.getLogger(__name__)
        logger.warning(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                get_rank(), len(buffer) / (1024 ** 3), device
            )
        )
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor

def _pad_to_largest_tensor(tensor, group):
    """
    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"
    local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device)
        for _ in range(world_size)
    ]
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    if local_size != max_size:
        padding = torch.zeros(
            (max_size - local_size,), dtype=torch.uint8, device=tensor.device
        )
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor

def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # receiving Tensor from all ranks
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
        for _ in size_list
    ]
    dist.all_gather(tensor_list, tensor, group=group)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))
    return data_list

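# Example (a minimal usage sketch, not part of the upstream file): during
# distributed evaluation, each rank builds its own list of results; all_gather
# returns one list per rank on every process, which can then be flattened.
# `local_results` is a hypothetical per-rank list of picklable objects.
def _example_merge_results(local_results):
    gathered = all_gather(local_results)  # list with one entry per rank
    merged = []
    for per_rank in gathered:
        merged.extend(per_rank)
    return merged
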
def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group=group) == 1:
        return [data]
    rank = dist.get_rank(group=group)

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)

    # receiving Tensor from all ranks
    if rank == dst:
        max_size = max(size_list)
        tensor_list = [
            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
            for _ in size_list
        ]
        dist.gather(tensor, tensor_list, dst=dst, group=group)

        data_list = []
        for size, tensor in zip(size_list, tensor_list):
            buffer = tensor.cpu().numpy().tobytes()[:size]
            data_list.append(pickle.loads(buffer))
        return data_list
    else:
        dist.gather(tensor, [], dst=dst, group=group)
        return []

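# Example (a minimal usage sketch, not part of the upstream file): send
# per-rank metrics to rank 0 only, e.g. to print one consolidated summary.
# Non-destination ranks receive an empty list from gather() and return None.
def _example_report_metrics(local_metrics):
    all_metrics = gather(local_metrics, dst=0)
    if not is_main_process():
        return None
    for rank_id, metrics in enumerate(all_metrics):
        print("rank {}: {}".format(rank_id, metrics))
    return all_metrics
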
def shared_random_seed():
    """
    Returns:
        int: a random number that is the same across all workers.
            If workers need a shared RNG, they can use this shared seed to
            create one.

    All workers must call this function, otherwise it will deadlock.
    """
    ints = np.random.randint(2 ** 31)
    all_ints = all_gather(ints)
    return all_ints[0]

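# Example (a minimal usage sketch, not part of the upstream file): build a
# NumPy RNG that is identical on every worker, e.g. to shuffle data the same
# way everywhere. All ranks must reach this call together, since
# shared_random_seed() itself performs an all_gather.
def _example_shared_rng():
    seed = shared_random_seed()
    return np.random.RandomState(seed)
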
def time_synchronized():
    """pytorch-accurate time"""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
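

# Example (a minimal usage sketch, not part of the upstream file): measure the
# wall-clock time of a forward pass, synchronizing CUDA on both sides of the
# measured region so queued GPU work is not missed. `module` and `inputs` are
# hypothetical stand-ins for whatever is being timed.
def _example_time_forward(module, inputs):
    start = time_synchronized()
    with torch.no_grad():
        outputs = module(inputs)
    elapsed = time_synchronized() - start
    return outputs, elapsed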