Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

meta_yolox_base.py 8.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. # Mahdi Abdollahpour, 27/11/2021, 02:42 PM, PyCharm, ByteTrack
  2. import torch
  3. import torch.distributed as dist
  4. import torch.nn as nn
  5. import os
  6. import random
  7. from .base_meta_exp import BaseMetaExp
  8. import learn2learn as l2l
  9. class MetaExp(BaseMetaExp):
  10. def __init__(self):
  11. super().__init__()
  12. # ---------------- model config ---------------- #
  13. self.num_classes = 80
  14. self.depth = 1.00
  15. self.width = 1.00
  16. # ---------------- dataloader config ---------------- #
  17. # set worker to 4 for shorter dataloader init time
  18. # TODO: deal with this multi threading
  19. self.data_num_workers = 4
  20. self.input_size = (640, 640)
  21. self.random_size = (14, 26)
  22. self.train_anns = ["instances_train2017.json"]
  23. self.val_anns = ["instances_val2017.json"]
  24. # --------------- transform config ----------------- #
  25. self.degrees = 10.0
  26. self.translate = 0.1
  27. self.scale = (0.1, 2)
  28. self.mscale = (0.8, 1.6)
  29. self.shear = 2.0
  30. self.perspective = 0.0
  31. self.enable_mixup = True
  32. # -------------- training config --------------------- #
  33. self.warmup_epochs = 5
  34. self.max_epoch = 300
  35. self.warmup_lr = 0
  36. self.basic_lr_per_img = 0.01 / 64.0
  37. self.scheduler = "yoloxwarmcos"
  38. self.no_aug_epochs = 15
  39. self.min_lr_ratio = 0.05
  40. self.ema = True
  41. self.weight_decay = 5e-4
  42. self.momentum = 0.9
  43. self.print_interval = 10
  44. self.eval_interval = 10
  45. self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
  46. # ----------------- testing config ------------------ #
  47. self.test_size = (640, 640)
  48. self.test_conf = 0.001
  49. self.nmsthre = 0.65
  50. # ----------------- Meta-learning ------------------ #
  51. self.first_order = True
  52. self.inner_lr = 1e-5
  53. def get_model(self):
  54. from yolox.models import YOLOPAFPN, YOLOX, YOLOXHead
  55. def init_yolo(M):
  56. for m in M.modules():
  57. if isinstance(m, nn.BatchNorm2d):
  58. m.eps = 1e-3
  59. m.momentum = 0.03
  60. if getattr(self, "model", None) is None:
  61. in_channels = [256, 512, 1024]
  62. backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels)
  63. head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels)
  64. self.model = YOLOX(backbone, head)
  65. self.model.apply(init_yolo)
  66. self.model.head.initialize_biases(1e-2)
  67. return self.model
  68. def get_data_loaders(self, batch_size, is_distributed, no_aug=False):
  69. from yolox.data import (
  70. COCODataset,
  71. DataLoader,
  72. InfiniteSampler,
  73. MosaicDetection,
  74. TrainTransform,
  75. YoloBatchSampler
  76. )
  77. train_loaders = []
  78. for train_ann in self.train_anns:
  79. dataset = COCODataset(
  80. data_dir=None,
  81. json_file=train_ann,
  82. img_size=self.input_size,
  83. preproc=TrainTransform(
  84. rgb_means=(0.485, 0.456, 0.406),
  85. std=(0.229, 0.224, 0.225),
  86. max_labels=50,
  87. ),
  88. )
  89. dataset = MosaicDetection(
  90. dataset,
  91. mosaic=not no_aug,
  92. img_size=self.input_size,
  93. preproc=TrainTransform(
  94. rgb_means=(0.485, 0.456, 0.406),
  95. std=(0.229, 0.224, 0.225),
  96. max_labels=120,
  97. ),
  98. degrees=self.degrees,
  99. translate=self.translate,
  100. scale=self.scale,
  101. shear=self.shear,
  102. perspective=self.perspective,
  103. enable_mixup=self.enable_mixup,
  104. )
  105. self.dataset = dataset
  106. if is_distributed:
  107. batch_size = batch_size // dist.get_world_size()
  108. sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)
  109. batch_sampler = YoloBatchSampler(
  110. sampler=sampler,
  111. batch_size=batch_size,
  112. drop_last=False,
  113. input_dimension=self.input_size,
  114. mosaic=not no_aug,
  115. )
  116. dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
  117. dataloader_kwargs["batch_sampler"] = batch_sampler
  118. train_loader = DataLoader(self.dataset, **dataloader_kwargs)
  119. train_loaders.append(train_loader)
  120. return train_loaders
  121. def random_resize(self, data_loader, epoch, rank, is_distributed):
  122. tensor = torch.LongTensor(2).cuda()
  123. if rank == 0:
  124. size_factor = self.input_size[1] * 1.0 / self.input_size[0]
  125. size = random.randint(*self.random_size)
  126. size = (int(32 * size), 32 * int(size * size_factor))
  127. tensor[0] = size[0]
  128. tensor[1] = size[1]
  129. if is_distributed:
  130. dist.barrier()
  131. dist.broadcast(tensor, 0)
  132. input_size = data_loader.change_input_dim(
  133. multiple=(tensor[0].item(), tensor[1].item()), random_range=None
  134. )
  135. return input_size
  136. def get_optimizer(self, batch_size):
  137. if "optimizer" not in self.__dict__:
  138. if self.warmup_epochs > 0:
  139. lr = self.warmup_lr
  140. else:
  141. lr = self.basic_lr_per_img * batch_size
  142. pg0, pg1, pg2 = [], [], [] # optimizer parameter groups
  143. for k, v in self.model.named_modules():
  144. if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
  145. pg2.append(v.bias) # biases
  146. if isinstance(v, nn.BatchNorm2d) or "bn" in k:
  147. pg0.append(v.weight) # no decay
  148. elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
  149. pg1.append(v.weight) # apply decay
  150. optimizer = torch.optim.SGD(
  151. pg0, lr=lr, momentum=self.momentum, nesterov=True
  152. )
  153. optimizer.add_param_group(
  154. {"params": pg1, "weight_decay": self.weight_decay}
  155. ) # add pg1 with weight_decay
  156. optimizer.add_param_group({"params": pg2})
  157. self.all_parameters = pg0 + pg1 + pg2
  158. self.optimizer = optimizer
  159. return self.optimizer
  160. def get_lr_scheduler(self, lr, iters_per_epoch):
  161. from yolox.utils import LRScheduler
  162. scheduler = LRScheduler(
  163. self.scheduler,
  164. lr,
  165. iters_per_epoch,
  166. self.max_epoch,
  167. warmup_epochs=self.warmup_epochs,
  168. warmup_lr_start=self.warmup_lr,
  169. no_aug_epochs=self.no_aug_epochs,
  170. min_lr_ratio=self.min_lr_ratio,
  171. )
  172. return scheduler
  173. def get_eval_loaders(self, batch_size, is_distributed, testdev=False):
  174. from yolox.data import COCODataset, ValTransform
  175. val_loaders = []
  176. for val_ann in self.val_anns:
  177. valdataset = COCODataset(
  178. data_dir=None,
  179. json_file=val_ann if not testdev else "image_info_test-dev2017.json",
  180. name="val2017" if not testdev else "test2017",
  181. img_size=self.test_size,
  182. preproc=ValTransform(
  183. rgb_means=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
  184. ),
  185. )
  186. if is_distributed:
  187. batch_size = batch_size // dist.get_world_size()
  188. sampler = torch.utils.data.distributed.DistributedSampler(
  189. valdataset, shuffle=False
  190. )
  191. else:
  192. sampler = torch.utils.data.SequentialSampler(valdataset)
  193. dataloader_kwargs = {
  194. "num_workers": self.data_num_workers,
  195. "pin_memory": True,
  196. "sampler": sampler,
  197. }
  198. dataloader_kwargs["batch_size"] = batch_size
  199. val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)
  200. val_loaders.append(val_loader)
  201. return val_loaders
  202. def get_evaluator(self, batch_size, is_distributed, testdev=False):
  203. from yolox.evaluators import COCOEvaluator
  204. val_loader = self.get_eval_loader(batch_size, is_distributed, testdev=testdev)
  205. evaluator = COCOEvaluator(
  206. dataloader=val_loader,
  207. img_size=self.test_size,
  208. confthre=self.test_conf,
  209. nmsthre=self.nmsthre,
  210. num_classes=self.num_classes,
  211. testdev=testdev,
  212. )
  213. return evaluator
  214. def eval(self, model, evaluator, is_distributed, half=False):
  215. return evaluator.evaluate(model, is_distributed, half)