# encoding: utf-8
import os
import random

import torch
import torch.nn as nn
import torch.distributed as dist

from yolox.exp import MetaExp as MyMetaExp
from yolox.data import get_yolox_datadir

from os import listdir
from os.path import isfile, join


class Exp(MyMetaExp):
    """YOLOX ``MetaExp`` configuration: train on MOT17 FRCNN sequences, validate on MOT20.

    One DataLoader is built per annotation file, for both training
    (``self.train_anns``) and validation (``self.val_anns``). The loaders are
    returned as lists in the (sorted) order of those annotation lists.
    """

    def __init__(self):
        super(Exp, self).__init__()
        self.num_classes = 1
        self.depth = 1.33
        self.width = 1.25
        # Experiment name = this file's base name up to the first '.'.
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

        # Training annotations: MOT17 JSON files whose names contain both 'train' and 'FRCNN'.
        self.train_dir = '/home/abdollahpour.ce.sharif/ByteTrackData/MOT17/annotations'
        self.train_anns = self._list_annotation_files(self.train_dir, ('train', 'FRCNN'))
        # TODO: remove
        # self.train_anns = self.train_anns[3:]

        # Validation annotations: MOT20 JSON files whose names contain both 'train' and 'MOT20'.
        self.val_dir = '/home/abdollahpour.ce.sharif/ByteTrackData/MOT20/annotations'
        self.val_anns = self._list_annotation_files(self.val_dir, ('train', 'MOT20'))

        print('train_anns', self.train_anns)
        print('val_anns', self.val_anns)

        self.input_size = (800, 1440)
        self.test_size = (896, 1600)
        # self.test_size = (736, 1920)
        self.random_size = (20, 36)
        self.max_epoch = 80
        self.print_interval = 100
        self.eval_interval = 5
        self.test_conf = 0.001
        self.nmsthre = 0.7
        self.no_aug_epochs = 10
        self.basic_lr_per_img = 0.001 / 64.0
        self.warmup_epochs = 1

    @staticmethod
    def _list_annotation_files(directory, required_substrings):
        """Return sorted names of regular files in *directory* containing every substring.

        ``os.listdir`` yields names in arbitrary order; sorting keeps the order
        of the per-annotation loaders deterministic across machines and runs.
        """
        return sorted(
            name
            for name in listdir(directory)
            if isfile(join(directory, name))
            and all(part in name for part in required_substrings)
        )

    def get_data_loaders(self, batch_size, is_distributed, no_aug=False):
        """Build one training DataLoader per annotation file in ``self.train_anns``.

        Args:
            batch_size: global batch size. When ``is_distributed`` is True it is
                divided (once) by ``dist.get_world_size()`` to get the
                per-process batch size.
            is_distributed: whether training runs under torch.distributed.
            no_aug: when True, mosaic is disabled in both ``MosaicDetection``
                and ``YoloBatchSampler``.

        Returns:
            list of DataLoader, in the same order as ``self.train_anns``.
        """
        from yolox.data import (
            MOTDataset,
            TrainTransform,
            YoloBatchSampler,
            DataLoader,
            InfiniteSampler,
            MosaicDetection,
        )

        # Split the global batch across processes exactly once. Doing this inside
        # the loop would divide it again for every additional annotation file.
        if is_distributed:
            batch_size = batch_size // dist.get_world_size()

        train_loaders = []
        for train_ann in self.train_anns:
            dataset = MOTDataset(
                data_dir=os.path.join(get_yolox_datadir(), "MOT17"),
                json_file=train_ann,
                name='train',
                img_size=self.input_size,
                preproc=TrainTransform(
                    rgb_means=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225),
                    max_labels=500,
                ),
            )
            dataset = MosaicDetection(
                dataset,
                mosaic=not no_aug,
                img_size=self.input_size,
                preproc=TrainTransform(
                    rgb_means=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225),
                    max_labels=1000,
                ),
                degrees=self.degrees,
                translate=self.translate,
                scale=self.scale,
                shear=self.shear,
                perspective=self.perspective,
                enable_mixup=self.enable_mixup,
            )
            # Kept for backward compatibility: ends up holding the dataset of
            # the last loader built.
            self.dataset = dataset

            sampler = InfiniteSampler(
                len(dataset), seed=self.seed if self.seed else 0
            )
            batch_sampler = YoloBatchSampler(
                sampler=sampler,
                batch_size=batch_size,
                drop_last=False,
                input_dimension=self.input_size,
                mosaic=not no_aug,
            )
            train_loader = DataLoader(
                dataset,
                num_workers=self.data_num_workers,
                pin_memory=True,
                batch_sampler=batch_sampler,
            )
            train_loaders.append(train_loader)
        return train_loaders

    def get_eval_loaders(self, batch_size, is_distributed, testdev=False):
        """Build one evaluation DataLoader per annotation file in ``self.val_anns``.

        Args:
            batch_size: global batch size. When ``is_distributed`` is True it is
                divided (once) by ``dist.get_world_size()``.
            is_distributed: use a non-shuffling ``DistributedSampler`` when True,
                otherwise a ``SequentialSampler``.
            testdev: accepted for interface compatibility; not used here.

        Returns:
            list of torch DataLoader, in the same order as ``self.val_anns``.
        """
        from yolox.data import MOTDataset, ValTransform

        # Split the global batch across processes exactly once (see get_data_loaders).
        if is_distributed:
            batch_size = batch_size // dist.get_world_size()

        val_loaders = []
        for val_ann in self.val_anns:
            valdataset = MOTDataset(
                data_dir=os.path.join(get_yolox_datadir(), "MOT20"),
                json_file=val_ann,
                img_size=self.test_size,
                # Presumably the image sub-directory under data_dir; 'train' matches
                # the MOT20 *train* annotation files selected in __init__.
                name='train',
                preproc=ValTransform(
                    rgb_means=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225),
                ),
            )

            if is_distributed:
                sampler = torch.utils.data.distributed.DistributedSampler(
                    valdataset, shuffle=False
                )
            else:
                sampler = torch.utils.data.SequentialSampler(valdataset)

            val_loader = torch.utils.data.DataLoader(
                valdataset,
                num_workers=self.data_num_workers,
                pin_memory=True,
                sampler=sampler,
                batch_size=batch_size,
            )
            val_loaders.append(val_loader)
        return val_loaders

    def get_evaluator(self, batch_size, is_distributed, testdev=False):
        """Return a ``COCOEvaluator`` over the evaluation loaders from ``get_eval_loaders``."""
        from yolox.evaluators import COCOEvaluator

        val_loaders = self.get_eval_loaders(batch_size, is_distributed, testdev=testdev)
        # NOTE(review): this passes a *list* of DataLoaders as `dataloader`; presumably
        # the COCOEvaluator used with MetaExp accepts a list -- confirm.
        evaluator = COCOEvaluator(
            dataloader=val_loaders,
            img_size=self.test_size,
            confthre=self.test_conf,
            nmsthre=self.nmsthre,
            num_classes=self.num_classes,
            testdev=testdev,
        )
        return evaluator