Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

yolox_base.py 8.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. import torch
  5. import torch.distributed as dist
  6. import torch.nn as nn
  7. import os
  8. import random
  9. from .base_exp import BaseExp
  10. class Exp(BaseExp):
  11. def __init__(self):
  12. super().__init__()
  13. # ---------------- model config ---------------- #
  14. self.num_classes = 80
  15. self.depth = 1.00
  16. self.width = 1.00
  17. # ---------------- dataloader config ---------------- #
  18. # set worker to 4 for shorter dataloader init time
  19. self.data_num_workers = 4
  20. self.input_size = (640, 640)
  21. self.random_size = (14, 26)
  22. self.train_ann = "instances_train2017.json"
  23. self.val_ann = "instances_val2017.json"
  24. # --------------- transform config ----------------- #
  25. self.degrees = 10.0
  26. self.translate = 0.1
  27. self.scale = (0.1, 2)
  28. self.mscale = (0.8, 1.6)
  29. self.shear = 2.0
  30. self.perspective = 0.0
  31. self.enable_mixup = True
  32. # -------------- training config --------------------- #
  33. self.warmup_epochs = 5
  34. self.max_epoch = 300
  35. self.warmup_lr = 0
  36. self.basic_lr_per_img = 0.01 / 64.0
  37. self.scheduler = "yoloxwarmcos"
  38. self.no_aug_epochs = 15
  39. self.min_lr_ratio = 0.05
  40. self.ema = True
  41. self.weight_decay = 5e-4
  42. self.momentum = 0.9
  43. self.print_interval = 10
  44. self.eval_interval = 10
  45. self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
  46. # ----------------- testing config ------------------ #
  47. self.test_size = (640, 640)
  48. self.test_conf = 0.001
  49. self.nmsthre = 0.65
  50. def get_model(self):
  51. from yolox.models import YOLOPAFPN, YOLOX, YOLOXHead
  52. def init_yolo(M):
  53. for m in M.modules():
  54. if isinstance(m, nn.BatchNorm2d):
  55. m.eps = 1e-3
  56. m.momentum = 0.03
  57. if getattr(self, "model", None) is None:
  58. in_channels = [256, 512, 1024]
  59. backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels)
  60. head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels)
  61. self.model = YOLOX(backbone, head)
  62. self.model.apply(init_yolo)
  63. self.model.head.initialize_biases(1e-2)
  64. return self.model
  65. def get_data_loader(self, batch_size, is_distributed, no_aug=False):
  66. from yolox.data import (
  67. COCODataset,
  68. DataLoader,
  69. InfiniteSampler,
  70. MosaicDetection,
  71. TrainTransform,
  72. YoloBatchSampler
  73. )
  74. dataset = COCODataset(
  75. data_dir=None,
  76. json_file=self.train_ann,
  77. img_size=self.input_size,
  78. preproc=TrainTransform(
  79. rgb_means=(0.485, 0.456, 0.406),
  80. std=(0.229, 0.224, 0.225),
  81. max_labels=50,
  82. ),
  83. )
  84. dataset = MosaicDetection(
  85. dataset,
  86. mosaic=not no_aug,
  87. img_size=self.input_size,
  88. preproc=TrainTransform(
  89. rgb_means=(0.485, 0.456, 0.406),
  90. std=(0.229, 0.224, 0.225),
  91. max_labels=120,
  92. ),
  93. degrees=self.degrees,
  94. translate=self.translate,
  95. scale=self.scale,
  96. shear=self.shear,
  97. perspective=self.perspective,
  98. enable_mixup=self.enable_mixup,
  99. )
  100. self.dataset = dataset
  101. if is_distributed:
  102. batch_size = batch_size // dist.get_world_size()
  103. sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)
  104. batch_sampler = YoloBatchSampler(
  105. sampler=sampler,
  106. batch_size=batch_size,
  107. drop_last=False,
  108. input_dimension=self.input_size,
  109. mosaic=not no_aug,
  110. )
  111. dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
  112. dataloader_kwargs["batch_sampler"] = batch_sampler
  113. train_loader = DataLoader(self.dataset, **dataloader_kwargs)
  114. return train_loader
  115. def random_resize(self, data_loader, epoch, rank, is_distributed):
  116. tensor = torch.LongTensor(2).cuda()
  117. if rank == 0:
  118. size_factor = self.input_size[1] * 1.0 / self.input_size[0]
  119. size = random.randint(*self.random_size)
  120. size = (int(32 * size), 32 * int(size * size_factor))
  121. tensor[0] = size[0]
  122. tensor[1] = size[1]
  123. if is_distributed:
  124. dist.barrier()
  125. dist.broadcast(tensor, 0)
  126. input_size = data_loader.change_input_dim(
  127. multiple=(tensor[0].item(), tensor[1].item()), random_range=None
  128. )
  129. return input_size
  130. def get_optimizer(self, batch_size):
  131. if "optimizer" not in self.__dict__:
  132. if self.warmup_epochs > 0:
  133. lr = self.warmup_lr
  134. else:
  135. lr = self.basic_lr_per_img * batch_size
  136. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  137. for k, v in self.model.named_modules():
  138. if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
  139. pg2.append(v.bias) # biases
  140. if isinstance(v, nn.BatchNorm2d) or "bn" in k:
  141. pg0.append(v.weight) # no decay
  142. elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
  143. pg1.append(v.weight) # apply decay
  144. optimizer = torch.optim.SGD(
  145. pg0, lr=lr, momentum=self.momentum, nesterov=True
  146. )
  147. optimizer.add_param_group(
  148. {"params": pg1, "weight_decay": self.weight_decay}
  149. ) # add pg1 with weight_decay
  150. optimizer.add_param_group({"params": pg2})
  151. self.optimizer = optimizer
  152. return self.optimizer
  153. def get_lr_scheduler(self, lr, iters_per_epoch):
  154. from yolox.utils import LRScheduler
  155. scheduler = LRScheduler(
  156. self.scheduler,
  157. lr,
  158. iters_per_epoch,
  159. self.max_epoch,
  160. warmup_epochs=self.warmup_epochs,
  161. warmup_lr_start=self.warmup_lr,
  162. no_aug_epochs=self.no_aug_epochs,
  163. min_lr_ratio=self.min_lr_ratio,
  164. )
  165. return scheduler
  166. def get_eval_loader(self, batch_size, is_distributed, testdev=False):
  167. from yolox.data import COCODataset, ValTransform
  168. valdataset = COCODataset(
  169. data_dir=None,
  170. json_file=self.val_ann if not testdev else "image_info_test-dev2017.json",
  171. name="val2017" if not testdev else "test2017",
  172. img_size=self.test_size,
  173. preproc=ValTransform(
  174. rgb_means=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
  175. ),
  176. )
  177. if is_distributed:
  178. batch_size = batch_size // dist.get_world_size()
  179. sampler = torch.utils.data.distributed.DistributedSampler(
  180. valdataset, shuffle=False
  181. )
  182. else:
  183. sampler = torch.utils.data.SequentialSampler(valdataset)
  184. dataloader_kwargs = {
  185. "num_workers": self.data_num_workers,
  186. "pin_memory": True,
  187. "sampler": sampler,
  188. }
  189. dataloader_kwargs["batch_size"] = batch_size
  190. val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)
  191. return val_loader
  192. def get_evaluator(self, batch_size, is_distributed, testdev=False):
  193. from yolox.evaluators import COCOEvaluator
  194. val_loader = self.get_eval_loader(batch_size, is_distributed, testdev=testdev)
  195. evaluator = COCOEvaluator(
  196. dataloader=val_loader,
  197. img_size=self.test_size,
  198. confthre=self.test_conf,
  199. nmsthre=self.nmsthre,
  200. num_classes=self.num_classes,
  201. testdev=testdev,
  202. )
  203. return evaluator
  204. def eval(self, model, evaluator, is_distributed, half=False):
  205. return evaluator.evaluate(model, is_distributed, half)