#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import os

import torch
import torch.nn as nn

from yolox.exp import Exp as MyExp


class Exp(MyExp):
    """YOLOX experiment configuration: YOLOFPN (Darknet-53) backbone with a YOLOX head."""

    def __init__(self):
        super(Exp, self).__init__()
        self.depth = 1.0
        self.width = 1.0
        # Name the experiment after this file (without its extension).
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

    def get_model(self, sublinear=False):

        def init_yolo(M):
            # Use the BatchNorm eps/momentum settings that YOLO-style training expects.
            for m in M.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-3
                    m.momentum = 0.03

        if "model" not in self.__dict__:
            from yolox.models import YOLOX, YOLOFPN, YOLOXHead
            backbone = YOLOFPN()
            head = YOLOXHead(
                self.num_classes, self.width, in_channels=[128, 256, 512], act="lrelu"
            )
            self.model = YOLOX(backbone, head)

        self.model.apply(init_yolo)
        self.model.head.initialize_biases(1e-2)
        return self.model

    def get_data_loader(self, batch_size, is_distributed, no_aug=False):
        from data.datasets.cocodataset import COCODataset
        from data.datasets.data_augment import TrainTransform
        from data.datasets.dataloading import (
            DataLoader,
            InfiniteSampler,
            YoloBatchSampler,
        )
        from data.datasets.mosaicdetection import MosaicDetection
        import torch.distributed as dist

        # Base COCO dataset with ImageNet mean/std normalization.
        dataset = COCODataset(
            data_dir="data/COCO/",
            json_file=self.train_ann,
            img_size=self.input_size,
            preproc=TrainTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_labels=50,
            ),
        )

        # Wrap the dataset with mosaic and random-affine augmentation
        # (disabled when no_aug=True).
        dataset = MosaicDetection(
            dataset,
            mosaic=not no_aug,
            img_size=self.input_size,
            preproc=TrainTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_labels=120,
            ),
            degrees=self.degrees,
            translate=self.translate,
            scale=self.scale,
            shear=self.shear,
            perspective=self.perspective,
        )

        self.dataset = dataset

        if is_distributed:
            # Split the global batch size across processes and sample indefinitely.
            batch_size = batch_size // dist.get_world_size()
            sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)
        else:
            sampler = torch.utils.data.RandomSampler(self.dataset)

        # Batch sampler that also carries the (possibly multi-scale) input dimension.
        batch_sampler = YoloBatchSampler(
            sampler=sampler,
            batch_size=batch_size,
            drop_last=False,
            input_dimension=self.input_size,
            mosaic=not no_aug,
        )

        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
        dataloader_kwargs["batch_sampler"] = batch_sampler
        train_loader = DataLoader(self.dataset, **dataloader_kwargs)

        return train_loader