Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

yolox_nano_mix_det.py 5.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # encoding: utf-8
  2. import os
  3. import random
  4. import torch
  5. import torch.nn as nn
  6. import torch.distributed as dist
  7. from yolox.exp import Exp as MyExp
  8. from yolox.data import get_yolox_datadir
  9. class Exp(MyExp):
  10. def __init__(self):
  11. super(Exp, self).__init__()
  12. self.num_classes = 1
  13. self.depth = 0.33
  14. self.width = 0.25
  15. self.scale = (0.5, 1.5)
  16. self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
  17. self.train_ann = "train.json"
  18. self.val_ann = "train.json"
  19. self.input_size = (608, 1088)
  20. self.test_size = (608, 1088)
  21. self.random_size = (12, 26)
  22. self.max_epoch = 80
  23. self.print_interval = 20
  24. self.eval_interval = 5
  25. self.test_conf = 0.001
  26. self.nmsthre = 0.7
  27. self.no_aug_epochs = 10
  28. self.basic_lr_per_img = 0.001 / 64.0
  29. self.warmup_epochs = 1
  30. def get_model(self, sublinear=False):
  31. def init_yolo(M):
  32. for m in M.modules():
  33. if isinstance(m, nn.BatchNorm2d):
  34. m.eps = 1e-3
  35. m.momentum = 0.03
  36. if "model" not in self.__dict__:
  37. from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead
  38. in_channels = [256, 512, 1024]
  39. # NANO model use depthwise = True, which is main difference.
  40. backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels, depthwise=True)
  41. head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels, depthwise=True)
  42. self.model = YOLOX(backbone, head)
  43. self.model.apply(init_yolo)
  44. self.model.head.initialize_biases(1e-2)
  45. return self.model
  46. def get_data_loader(self, batch_size, is_distributed, no_aug=False):
  47. from yolox.data import (
  48. MOTDataset,
  49. TrainTransform,
  50. YoloBatchSampler,
  51. DataLoader,
  52. InfiniteSampler,
  53. MosaicDetection,
  54. )
  55. dataset = MOTDataset(
  56. data_dir=os.path.join(get_yolox_datadir(), "mix_det"),
  57. json_file=self.train_ann,
  58. name='',
  59. img_size=self.input_size,
  60. preproc=TrainTransform(
  61. rgb_means=(0.485, 0.456, 0.406),
  62. std=(0.229, 0.224, 0.225),
  63. max_labels=500,
  64. ),
  65. )
  66. dataset = MosaicDetection(
  67. dataset,
  68. mosaic=not no_aug,
  69. img_size=self.input_size,
  70. preproc=TrainTransform(
  71. rgb_means=(0.485, 0.456, 0.406),
  72. std=(0.229, 0.224, 0.225),
  73. max_labels=1000,
  74. ),
  75. degrees=self.degrees,
  76. translate=self.translate,
  77. scale=self.scale,
  78. shear=self.shear,
  79. perspective=self.perspective,
  80. enable_mixup=self.enable_mixup,
  81. )
  82. self.dataset = dataset
  83. if is_distributed:
  84. batch_size = batch_size // dist.get_world_size()
  85. sampler = InfiniteSampler(
  86. len(self.dataset), seed=self.seed if self.seed else 0
  87. )
  88. batch_sampler = YoloBatchSampler(
  89. sampler=sampler,
  90. batch_size=batch_size,
  91. drop_last=False,
  92. input_dimension=self.input_size,
  93. mosaic=not no_aug,
  94. )
  95. dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
  96. dataloader_kwargs["batch_sampler"] = batch_sampler
  97. train_loader = DataLoader(self.dataset, **dataloader_kwargs)
  98. return train_loader
  99. def get_eval_loader(self, batch_size, is_distributed, testdev=False):
  100. from yolox.data import MOTDataset, ValTransform
  101. valdataset = MOTDataset(
  102. data_dir=os.path.join(get_yolox_datadir(), "mot"),
  103. json_file=self.val_ann,
  104. img_size=self.test_size,
  105. name='train',
  106. preproc=ValTransform(
  107. rgb_means=(0.485, 0.456, 0.406),
  108. std=(0.229, 0.224, 0.225),
  109. ),
  110. )
  111. if is_distributed:
  112. batch_size = batch_size // dist.get_world_size()
  113. sampler = torch.utils.data.distributed.DistributedSampler(
  114. valdataset, shuffle=False
  115. )
  116. else:
  117. sampler = torch.utils.data.SequentialSampler(valdataset)
  118. dataloader_kwargs = {
  119. "num_workers": self.data_num_workers,
  120. "pin_memory": True,
  121. "sampler": sampler,
  122. }
  123. dataloader_kwargs["batch_size"] = batch_size
  124. val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)
  125. return val_loader
  126. def get_evaluator(self, batch_size, is_distributed, testdev=False):
  127. from yolox.evaluators import COCOEvaluator
  128. val_loader = self.get_eval_loader(batch_size, is_distributed, testdev=testdev)
  129. evaluator = COCOEvaluator(
  130. dataloader=val_loader,
  131. img_size=self.test_size,
  132. confthre=self.test_conf,
  133. nmsthre=self.nmsthre,
  134. num_classes=self.num_classes,
  135. testdev=testdev,
  136. )
  137. return evaluator