Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

yolox_x_mix_mot20_ch.py 4.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # encoding: utf-8
  2. import os
  3. import random
  4. import torch
  5. import torch.nn as nn
  6. import torch.distributed as dist
  7. from yolox.exp import Exp as MyExp
  8. from yolox.data import get_yolox_datadir
  9. class Exp(MyExp):
  10. def __init__(self):
  11. super(Exp, self).__init__()
  12. self.num_classes = 1
  13. self.depth = 1.33
  14. self.width = 1.25
  15. self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
  16. self.train_ann = "train.json"
  17. self.val_ann = "test.json" # change to train.json when running on training set
  18. self.input_size = (896, 1600)
  19. self.test_size = (896, 1600)
  20. #self.test_size = (736, 1920)
  21. self.random_size = (20, 36)
  22. self.max_epoch = 80
  23. self.print_interval = 20
  24. self.eval_interval = 5
  25. self.test_conf = 0.001
  26. self.nmsthre = 0.7
  27. self.no_aug_epochs = 10
  28. self.basic_lr_per_img = 0.001 / 64.0
  29. self.warmup_epochs = 1
  30. def get_data_loader(self, batch_size, is_distributed, no_aug=False):
  31. from yolox.data import (
  32. MOTDataset,
  33. TrainTransform,
  34. YoloBatchSampler,
  35. DataLoader,
  36. InfiniteSampler,
  37. MosaicDetection,
  38. )
  39. dataset = MOTDataset(
  40. data_dir=os.path.join(get_yolox_datadir(), "mix_mot20_ch"),
  41. json_file=self.train_ann,
  42. name='',
  43. img_size=self.input_size,
  44. preproc=TrainTransform(
  45. rgb_means=(0.485, 0.456, 0.406),
  46. std=(0.229, 0.224, 0.225),
  47. max_labels=600,
  48. ),
  49. )
  50. dataset = MosaicDetection(
  51. dataset,
  52. mosaic=not no_aug,
  53. img_size=self.input_size,
  54. preproc=TrainTransform(
  55. rgb_means=(0.485, 0.456, 0.406),
  56. std=(0.229, 0.224, 0.225),
  57. max_labels=1200,
  58. ),
  59. degrees=self.degrees,
  60. translate=self.translate,
  61. scale=self.scale,
  62. shear=self.shear,
  63. perspective=self.perspective,
  64. enable_mixup=self.enable_mixup,
  65. )
  66. self.dataset = dataset
  67. if is_distributed:
  68. batch_size = batch_size // dist.get_world_size()
  69. sampler = InfiniteSampler(
  70. len(self.dataset), seed=self.seed if self.seed else 0
  71. )
  72. batch_sampler = YoloBatchSampler(
  73. sampler=sampler,
  74. batch_size=batch_size,
  75. drop_last=False,
  76. input_dimension=self.input_size,
  77. mosaic=not no_aug,
  78. )
  79. dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
  80. dataloader_kwargs["batch_sampler"] = batch_sampler
  81. train_loader = DataLoader(self.dataset, **dataloader_kwargs)
  82. return train_loader
  83. def get_eval_loader(self, batch_size, is_distributed, testdev=False):
  84. from yolox.data import MOTDataset, ValTransform
  85. valdataset = MOTDataset(
  86. data_dir=os.path.join(get_yolox_datadir(), "MOT20"),
  87. json_file=self.val_ann,
  88. img_size=self.test_size,
  89. name='test', # change to train when running on training set
  90. preproc=ValTransform(
  91. rgb_means=(0.485, 0.456, 0.406),
  92. std=(0.229, 0.224, 0.225),
  93. ),
  94. )
  95. if is_distributed:
  96. batch_size = batch_size // dist.get_world_size()
  97. sampler = torch.utils.data.distributed.DistributedSampler(
  98. valdataset, shuffle=False
  99. )
  100. else:
  101. sampler = torch.utils.data.SequentialSampler(valdataset)
  102. dataloader_kwargs = {
  103. "num_workers": self.data_num_workers,
  104. "pin_memory": True,
  105. "sampler": sampler,
  106. }
  107. dataloader_kwargs["batch_size"] = batch_size
  108. val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)
  109. return val_loader
  110. def get_evaluator(self, batch_size, is_distributed, testdev=False):
  111. from yolox.evaluators import COCOEvaluator
  112. val_loader = self.get_eval_loader(batch_size, is_distributed, testdev=testdev)
  113. evaluator = COCOEvaluator(
  114. dataloader=val_loader,
  115. img_size=self.test_size,
  116. confthre=self.test_conf,
  117. nmsthre=self.nmsthre,
  118. num_classes=self.num_classes,
  119. testdev=testdev,
  120. )
  121. return evaluator