Meta Byte Track

yolox_x_mot17_on_mot20.py 5.9KB

# encoding: utf-8
import os
import random
import torch
import torch.nn as nn
import torch.distributed as dist

from yolox.exp import MetaExp as MyMetaExp
from yolox.data import get_yolox_datadir
from os import listdir
from os.path import isfile, join


class Exp(MyMetaExp):
    def __init__(self):
        super(Exp, self).__init__()
        self.num_classes = 1
        # depth 1.33 / width 1.25 corresponds to the YOLOX-X backbone.
        self.depth = 1.33
        self.width = 1.25
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

        # Train on every MOT17 "train" annotation file with FRCNN detections.
        self.train_dir = '/home/abdollahpour.ce.sharif/ByteTrackData/MOT17/annotations'
        onlyfiles = [f for f in listdir(self.train_dir) if isfile(join(self.train_dir, f))]
        self.train_anns = [file for file in onlyfiles if 'train' in file and 'FRCNN' in file]
        # # TODO: remove
        # self.train_anns = self.train_anns[3:]

        # Evaluate on the MOT20 "train" annotation files.
        self.val_dir = '/home/abdollahpour.ce.sharif/ByteTrackData/MOT20/annotations'
        onlyfiles = [f for f in listdir(self.val_dir) if isfile(join(self.val_dir, f))]
        self.val_anns = [file for file in onlyfiles if 'train' in file and 'MOT20' in file]
        # self.val_anns = self.val_anns[-1:]
        print('train_anns', self.train_anns)
        print('val_anns', self.val_anns)

        self.input_size = (800, 1440)
        # TODO: try this
        self.test_size = (800, 1440)
        # self.test_size = (896, 1600)
        # self.test_size = (736, 1920)
        self.random_size = (20, 36)
        self.max_epoch = 80
        self.print_interval = 250
        self.eval_interval = 10
        self.test_conf = 0.001
        self.nmsthre = 0.7
        self.no_aug_epochs = 10
        # self.basic_lr_per_img = 0.001 / 64.0
        self.basic_lr_per_img = 0.0001 / 64.0
        self.warmup_epochs = 1

    def get_data_loaders(self, batch_size, is_distributed, no_aug=False):
        """Build one training DataLoader per MOT17 annotation file."""
        from yolox.data import (
            MOTDataset,
            TrainTransform,
            YoloBatchSampler,
            DataLoader,
            InfiniteSampler,
            MosaicDetection,
        )

        train_loaders = []
        for train_ann in self.train_anns:
            dataset = MOTDataset(
                data_dir=os.path.join(get_yolox_datadir(), "MOT17"),
                json_file=train_ann,
                name='train',
                img_size=self.input_size,
                preproc=TrainTransform(
                    rgb_means=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225),
                    max_labels=500,
                ),
            )
            dataset = MosaicDetection(
                dataset,
                mosaic=not no_aug,
                img_size=self.input_size,
                preproc=TrainTransform(
                    rgb_means=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225),
                    max_labels=1000,
                ),
                degrees=self.degrees,
                translate=self.translate,
                scale=self.scale,
                shear=self.shear,
                perspective=self.perspective,
                enable_mixup=self.enable_mixup,
            )
            self.dataset = dataset

            if is_distributed:
                batch_size = batch_size // dist.get_world_size()

            sampler = InfiniteSampler(
                len(self.dataset), seed=self.seed if self.seed else 0
            )
            batch_sampler = YoloBatchSampler(
                sampler=sampler,
                batch_size=batch_size,
                drop_last=False,
                input_dimension=self.input_size,
                mosaic=not no_aug,
            )

            dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
            dataloader_kwargs["batch_sampler"] = batch_sampler
            train_loader = DataLoader(self.dataset, **dataloader_kwargs)
            train_loaders.append(train_loader)
        return train_loaders

    def get_eval_loaders(self, batch_size, is_distributed, testdev=False):
        """Build one evaluation DataLoader per MOT20 annotation file."""
        from yolox.data import MOTDataset, ValTransform, ValTransformWithPseudo, TrainTransform

        val_loaders = []
        for val_ann in self.val_anns:
            valdataset = MOTDataset(
                data_dir=os.path.join(get_yolox_datadir(), "MOT20"),
                json_file=val_ann,
                img_size=self.test_size,
                name='train',  # change to train when running on training set
                preproc=ValTransformWithPseudo(
                    rgb_means=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225),
                    max_labels=500,
                ),
                load_weak=False,
            )

            if is_distributed:
                batch_size = batch_size // dist.get_world_size()
                sampler = torch.utils.data.distributed.DistributedSampler(
                    valdataset, shuffle=False
                )
            else:
                sampler = torch.utils.data.SequentialSampler(valdataset)

            dataloader_kwargs = {
                "num_workers": self.data_num_workers,
                "pin_memory": True,
                "sampler": sampler,
            }
            dataloader_kwargs["batch_size"] = batch_size
            val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)
            val_loaders.append(val_loader)
        return val_loaders

    def get_evaluators(self, batch_size, is_distributed, testdev=False):
        """Wrap each evaluation loader in a COCOEvaluator."""
        from yolox.evaluators import COCOEvaluator

        val_loaders = self.get_eval_loaders(batch_size, is_distributed, testdev=testdev)
        evaluators = []
        for val_loader in val_loaders:
            evaluator = COCOEvaluator(
                dataloader=val_loader,
                img_size=self.test_size,
                confthre=self.test_conf,
                nmsthre=self.nmsthre,
                num_classes=self.num_classes,
                testdev=testdev,
            )
            evaluators.append(evaluator)
        return evaluators