|
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- #!/usr/bin/env python3
- # -*- coding:utf-8 -*-
- # Copyright (c) Megvii, Inc. and its affiliates.
-
- import torch
- import torch.distributed as dist
-
- from yolox.utils import synchronize
-
- import random
-
-
class DataPrefetcher:
    """
    Prefetch the next batch onto the GPU on a side CUDA stream while the
    current batch is being consumed.

    Modeled on the prefetcher in
    https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py;
    background discussion:
    https://github.com/NVIDIA/apex/issues/304#issuecomment-493562789.

    Each item produced by the wrapped loader is unpacked into exactly four
    values; only the first two (input and target) are kept and moved to the GPU.
    """

    def __init__(self, loader):
        # Hooks used by preload()/next(); stored as attributes so they can be
        # swapped out per instance.
        self.input_cuda = self._input_cuda_for_image
        self.record_stream = DataPrefetcher._record_stream_for_image
        self.loader = iter(loader)
        # Dedicated stream on which host-to-device copies are issued.
        self.stream = torch.cuda.Stream()
        # Stage the first batch immediately.
        self.preload()

    def preload(self):
        """Fetch the next batch and start copying it to the GPU asynchronously.

        When the loader is exhausted, both ``next_input`` and ``next_target``
        are set to ``None`` and no copy is issued.
        """
        try:
            staged_input, staged_target, _, _ = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        self.next_input, self.next_target = staged_input, staged_target

        # Issue the copies on the side stream so they can overlap with work on
        # the current stream.
        with torch.cuda.stream(self.stream):
            self.input_cuda()
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        """Return the staged ``(input, target)`` pair and begin staging the next.

        Returns ``(None, None)`` once the loader is exhausted.
        """
        active_stream = torch.cuda.current_stream()
        # Make the consuming stream wait for the side-stream copies to finish.
        active_stream.wait_stream(self.stream)
        batch_input, batch_target = self.next_input, self.next_target
        # Mark the tensors as used by the consuming stream so the caching
        # allocator does not reuse their memory while that stream still needs it.
        if batch_input is not None:
            self.record_stream(batch_input)
        if batch_target is not None:
            batch_target.record_stream(active_stream)
        self.preload()
        return batch_input, batch_target

    def _input_cuda_for_image(self):
        """Move the staged input tensor to the GPU (non-blocking)."""
        self.next_input = self.next_input.cuda(non_blocking=True)

    @staticmethod
    def _record_stream_for_image(input):
        """Record the current CUDA stream as a user of ``input``."""
        input.record_stream(torch.cuda.current_stream())
-
-
def random_resize(data_loader, exp, epoch, rank, is_distributed, device="cuda"):
    """Choose a training input size on rank 0 and apply it to ``data_loader``.

    Rank 0 picks the size:

    * if ``epoch > exp.max_epoch - 10`` it uses ``exp.input_size`` unchanged;
    * otherwise it draws an integer from ``exp.random_size`` (inclusive bounds,
      via ``random.randint``) and multiplies it by 32.

    When ``is_distributed`` is true, every process synchronizes and the value
    is broadcast from rank 0, so all ranks apply the same size.

    Args:
        data_loader: object exposing ``change_input_dim(multiple=..., random_range=...)``.
        exp: experiment config; ``max_epoch``, ``input_size`` and ``random_size``
            are read.
        epoch: current epoch number.
        rank: rank of this process; only rank 0 chooses the size.
        is_distributed: whether to synchronize and broadcast across processes.
        device: device of the temporary broadcast tensor. Defaults to
            ``"cuda"`` (the original behavior). It must suit the process-group
            backend (e.g. CUDA tensors for NCCL, CPU tensors for gloo).

    Returns:
        Whatever ``data_loader.change_input_dim`` returns.
    """
    # Zero-initialized (rather than the legacy uninitialized LongTensor) so a
    # rank that neither chooses nor receives a size sees a defined value.
    tensor = torch.zeros(1, dtype=torch.long, device=device)
    if is_distributed:
        synchronize()

    if rank == 0:
        if epoch > exp.max_epoch - 10:
            # NOTE(review): fill_ needs a scalar, so this assumes exp.input_size
            # is a single int side length — confirm against the Exp definition.
            size = exp.input_size
        else:
            size = int(32 * random.randint(*exp.random_size))
        tensor.fill_(size)

    if is_distributed:
        # Make sure rank 0 has filled the tensor before it is broadcast.
        synchronize()
        dist.broadcast(tensor, 0)

    # random_range=None presumably tells the loader to use `multiple` as-is
    # instead of drawing its own random factor — TODO confirm in change_input_dim.
    return data_loader.change_input_dim(multiple=tensor.item(), random_range=None)
|