Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

yolov3.py 2.9KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. # Copyright (c) Megvii, Inc. and its affiliates.
  4. import os
  5. import torch
  6. import torch.nn as nn
  7. from yolox.exp import Exp as MyExp
  8. class Exp(MyExp):
  9. def __init__(self):
  10. super(Exp, self).__init__()
  11. self.depth = 1.0
  12. self.width = 1.0
  13. self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
  14. def get_model(self, sublinear=False):
  15. def init_yolo(M):
  16. for m in M.modules():
  17. if isinstance(m, nn.BatchNorm2d):
  18. m.eps = 1e-3
  19. m.momentum = 0.03
  20. if "model" not in self.__dict__:
  21. from yolox.models import YOLOX, YOLOFPN, YOLOXHead
  22. backbone = YOLOFPN()
  23. head = YOLOXHead(self.num_classes, self.width, in_channels=[128, 256, 512], act="lrelu")
  24. self.model = YOLOX(backbone, head)
  25. self.model.apply(init_yolo)
  26. self.model.head.initialize_biases(1e-2)
  27. return self.model
  28. def get_data_loader(self, batch_size, is_distributed, no_aug=False):
  29. from data.datasets.cocodataset import COCODataset
  30. from data.datasets.mosaicdetection import MosaicDetection
  31. from data.datasets.data_augment import TrainTransform
  32. from data.datasets.dataloading import YoloBatchSampler, DataLoader, InfiniteSampler
  33. import torch.distributed as dist
  34. dataset = COCODataset(
  35. data_dir='data/COCO/',
  36. json_file=self.train_ann,
  37. img_size=self.input_size,
  38. preproc=TrainTransform(
  39. rgb_means=(0.485, 0.456, 0.406),
  40. std=(0.229, 0.224, 0.225),
  41. max_labels=50
  42. ),
  43. )
  44. dataset = MosaicDetection(
  45. dataset,
  46. mosaic=not no_aug,
  47. img_size=self.input_size,
  48. preproc=TrainTransform(
  49. rgb_means=(0.485, 0.456, 0.406),
  50. std=(0.229, 0.224, 0.225),
  51. max_labels=120
  52. ),
  53. degrees=self.degrees,
  54. translate=self.translate,
  55. scale=self.scale,
  56. shear=self.shear,
  57. perspective=self.perspective,
  58. )
  59. self.dataset = dataset
  60. if is_distributed:
  61. batch_size = batch_size // dist.get_world_size()
  62. sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)
  63. else:
  64. sampler = torch.utils.data.RandomSampler(self.dataset)
  65. batch_sampler = YoloBatchSampler(
  66. sampler=sampler,
  67. batch_size=batch_size,
  68. drop_last=False,
  69. input_dimension=self.input_size,
  70. mosaic=not no_aug
  71. )
  72. dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
  73. dataloader_kwargs["batch_sampler"] = batch_sampler
  74. train_loader = DataLoader(self.dataset, **dataloader_kwargs)
  75. return train_loader