1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- #!/usr/bin/env python3
- # -*- coding:utf-8 -*-
- # Copyright (c) Megvii, Inc. and its affiliates.
-
- import os
- import torch
- import torch.nn as nn
-
- from yolox.exp import Exp as MyExp
-
-
class Exp(MyExp):
    """Experiment pairing a YOLOFPN backbone with a YOLOX detection head.

    Overrides model construction and the training data loader of ``MyExp``;
    everything else read here (``num_classes``, ``input_size``, augmentation
    parameters, ``data_num_workers``, ``seed``, ``train_ann``, ...) is
    expected to be provided by the base class.
    """

    def __init__(self):
        super(Exp, self).__init__()
        self.depth = 1.0
        self.width = 1.0
        # Experiment name is this file's basename with everything from the
        # first "." onward stripped.
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

    def get_model(self, sublinear=False):
        """Build the YOLOFPN + YOLOXHead detector on first call and return it.

        Args:
            sublinear: unused; kept so existing callers keep working.

        Returns:
            ``self.model``. It is constructed and initialized only when it is
            not already present in the instance ``__dict__``.
        """
        if "model" not in self.__dict__:
            from yolox.models import YOLOX, YOLOFPN, YOLOXHead

            backbone = YOLOFPN()
            head = YOLOXHead(
                self.num_classes, self.width, in_channels=[128, 256, 512], act="lrelu"
            )
            self.model = YOLOX(backbone, head)
            # modules() already walks the entire module tree (root included),
            # so one pass reaches every BatchNorm2d layer exactly once.
            for m in self.model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-3
                    m.momentum = 0.03
            self.model.head.initialize_biases(1e-2)

        return self.model

    def get_data_loader(self, batch_size, is_distributed, no_aug=False):
        """Build the COCO training DataLoader wrapped in mosaic augmentation.

        Side effect: the constructed dataset is stored on ``self.dataset``.

        Args:
            batch_size: total batch size. When ``is_distributed`` is true it is
                integer-divided by ``torch.distributed.get_world_size()``.
            is_distributed: if true, sample with ``InfiniteSampler`` seeded by
                ``self.seed`` (0 when unset/falsy); otherwise use
                ``torch.utils.data.RandomSampler``.
            no_aug: if true, mosaic is disabled both in ``MosaicDetection``
                and in ``YoloBatchSampler``.

        Returns:
            A ``DataLoader`` driven by a ``YoloBatchSampler``.
        """
        from data.datasets.cocodataset import COCODataset
        from data.datasets.mosaicdetection import MosaicDetection
        from data.datasets.data_augment import TrainTransform
        from data.datasets.dataloading import YoloBatchSampler, DataLoader, InfiniteSampler
        import torch.distributed as dist

        # Normalization statistics shared by both TrainTransform instances.
        rgb_means = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        dataset = COCODataset(
            data_dir='data/COCO/',
            json_file=self.train_ann,
            img_size=self.input_size,
            preproc=TrainTransform(rgb_means=rgb_means, std=std, max_labels=50),
        )

        dataset = MosaicDetection(
            dataset,
            mosaic=not no_aug,
            img_size=self.input_size,
            preproc=TrainTransform(rgb_means=rgb_means, std=std, max_labels=120),
            degrees=self.degrees,
            translate=self.translate,
            scale=self.scale,
            shear=self.shear,
            perspective=self.perspective,
        )

        self.dataset = dataset

        if is_distributed:
            # Convert the global batch size into a per-process batch size.
            batch_size = batch_size // dist.get_world_size()
            sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)
        else:
            sampler = torch.utils.data.RandomSampler(self.dataset)

        batch_sampler = YoloBatchSampler(
            sampler=sampler,
            batch_size=batch_size,
            drop_last=False,
            input_dimension=self.input_size,
            mosaic=not no_aug,
        )

        dataloader_kwargs = {
            "num_workers": self.data_num_workers,
            "pin_memory": True,
            "batch_sampler": batch_sampler,
        }
        train_loader = DataLoader(self.dataset, **dataloader_kwargs)

        return train_loader
|