Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data_prefetcher.py 2.2KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. # Copyright (c) Megvii, Inc. and its affiliates.
  4. import torch
  5. import torch.distributed as dist
  6. from yolox.utils import synchronize
  7. import random
class DataPrefetcher:
    """
    DataPrefetcher is inspired by code of following file:
    https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py
    It could speedup your pytorch dataloader. For more information, please check
    https://github.com/NVIDIA/apex/issues/304#issuecomment-493562789.

    Batches are copied host-to-device on a dedicated side CUDA stream so the
    copy of batch N+1 overlaps with compute on batch N.
    """

    def __init__(self, loader):
        # Wrap the dataloader in an explicit iterator so preload() can call next() on it.
        self.loader = iter(loader)
        # Dedicated CUDA stream used only for host-to-device copies.
        self.stream = torch.cuda.Stream()
        # Indirection kept so image-specific handling could be swapped out.
        self.input_cuda = self._input_cuda_for_image
        self.record_stream = DataPrefetcher._record_stream_for_image
        # Stage the first batch immediately so next() always has data ready.
        self.preload()

    def preload(self):
        # Pull the next batch from the loader; the loader yields 4 items and the
        # last two are intentionally discarded here.
        try:
            self.next_input, self.next_target, _, _ = next(self.loader)
        except StopIteration:
            # Loader exhausted: None sentinels tell next()'s caller we are done.
            self.next_input = None
            self.next_target = None
            return
        # Issue the device copies asynchronously on the side stream.
        with torch.cuda.stream(self.stream):
            self.input_cuda()
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        # Block the compute (current) stream until the staged copies complete.
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        if input is not None:
            # record_stream marks the tensor as in use by the current stream so
            # the caching allocator does not recycle its memory too early.
            self.record_stream(input)
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        # Immediately stage the following batch while the caller consumes this one.
        self.preload()
        return input, target

    def _input_cuda_for_image(self):
        # Asynchronous host-to-device copy of the image batch.
        self.next_input = self.next_input.cuda(non_blocking=True)

    @staticmethod
    def _record_stream_for_image(input):
        # Tie the tensor's lifetime to the current stream (see next()).
        input.record_stream(torch.cuda.current_stream())
  46. def random_resize(data_loader, exp, epoch, rank, is_distributed):
  47. tensor = torch.LongTensor(1).cuda()
  48. if is_distributed:
  49. synchronize()
  50. if rank == 0:
  51. if epoch > exp.max_epoch - 10:
  52. size = exp.input_size
  53. else:
  54. size = random.randint(*exp.random_size)
  55. size = int(32 * size)
  56. tensor.fill_(size)
  57. if is_distributed:
  58. synchronize()
  59. dist.broadcast(tensor, 0)
  60. input_size = data_loader.change_input_dim(multiple=tensor.item(), random_range=None)
  61. return input_size