- # ------------------------------------------------------------------------
- # Copyright (c) 2021 megvii-model. All Rights Reserved.
- # ------------------------------------------------------------------------
- # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
- # Copyright (c) 2020 SenseTime. All Rights Reserved.
- # ------------------------------------------------------------------------
- # Modified from DETR (https://github.com/facebookresearch/detr)
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
- # ------------------------------------------------------------------------
-
- """
- MOT dataset which returns image_id for evaluation.
- """
- from pathlib import Path
- import cv2
- import numpy as np
- import torch
- import torch.utils.data
- import os.path as osp
- from PIL import Image, ImageDraw
- import copy
- import datasets.transforms as T
- from models.structures import Instances
-
-
- class DetMOTDetection:
- def __init__(self, args, data_txt_path: str, seqs_folder, dataset2transform):
- self.args = args
- self.dataset2transform = dataset2transform
- self.num_frames_per_batch = max(args.sampler_lengths)
- self.sample_mode = args.sample_mode
- self.sample_interval = args.sample_interval
- self.vis = args.vis
- self.video_dict = {}
-
- with open(data_txt_path, 'r') as file:
- self.img_files = file.readlines()
- self.img_files = [osp.join(seqs_folder, x.strip()) for x in self.img_files]
- self.img_files = list(filter(lambda x: len(x) > 0, self.img_files))
-
- self.label_files = [(x.replace('images', 'labels_with_ids').replace('.png', '.txt').replace('.jpg', '.txt'))
- for x in self.img_files]
- # The number of images per sample: 1 + (num_frames - 1) * interval.
- # The number of valid samples: num_images - num_image_per_sample + 1.
- self.item_num = len(self.img_files) - (self.num_frames_per_batch - 1) * self.sample_interval
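-         # Worked example (hypothetical numbers): with 1000 listed images,
-         # num_frames_per_batch=5 and sample_interval=10, one sample spans
-         # 1 + (5 - 1) * 10 = 41 images, leaving 1000 - 40 = 960 valid start indices.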
-
- self._register_videos()
-
- # video sampler.
- self.sampler_steps: list = args.sampler_steps
- self.lengths: list = args.sampler_lengths
-         print("sampler_steps={} lengths={}".format(self.sampler_steps, self.lengths))
- if self.sampler_steps is not None and len(self.sampler_steps) > 0:
- # Enable sampling length adjustment.
- assert len(self.lengths) > 0
- assert len(self.lengths) == len(self.sampler_steps) + 1
- for i in range(len(self.sampler_steps) - 1):
- assert self.sampler_steps[i] < self.sampler_steps[i + 1]
- self.item_num = len(self.img_files) - (self.lengths[-1] - 1) * self.sample_interval
- self.period_idx = 0
- self.num_frames_per_batch = self.lengths[0]
- self.current_epoch = 0
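-         # Illustrative schedule (assumed values): with sampler_steps=[50, 90] and
-         # sampler_lengths=[2, 3, 5], epochs [0, 50) sample 2 frames per clip,
-         # epochs [50, 90) sample 3, and epochs >= 90 sample 5 (see set_epoch).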
-
- def _register_videos(self):
- for label_name in self.label_files:
- video_name = '/'.join(label_name.split('/')[:-1])
- if video_name not in self.video_dict:
- print("register {}-th video: {} ".format(len(self.video_dict) + 1, video_name))
- self.video_dict[video_name] = len(self.video_dict)
- # assert len(self.video_dict) <= 300
-
- def set_epoch(self, epoch):
- self.current_epoch = epoch
- if self.sampler_steps is None or len(self.sampler_steps) == 0:
- # fixed sampling length.
- return
-
- for i in range(len(self.sampler_steps)):
- if epoch >= self.sampler_steps[i]:
- self.period_idx = i + 1
- print("set epoch: epoch {} period_idx={}".format(epoch, self.period_idx))
- self.num_frames_per_batch = self.lengths[self.period_idx]
-
- def step_epoch(self):
- # one epoch finishes.
- print("Dataset: epoch {} finishes".format(self.current_epoch))
- self.set_epoch(self.current_epoch + 1)
-
- @staticmethod
- def _targets_to_instances(targets: dict, img_shape) -> Instances:
- gt_instances = Instances(tuple(img_shape))
- gt_instances.boxes = targets['boxes']
- gt_instances.labels = targets['labels']
- gt_instances.obj_ids = targets['obj_ids']
- gt_instances.area = targets['area']
- return gt_instances
-
- def _pre_single_frame(self, idx: int):
- img_path = self.img_files[idx]
- label_path = self.label_files[idx]
- if 'crowdhuman' in img_path:
- img_path = img_path.replace('.jpg', '.png')
- img = Image.open(img_path)
- targets = {}
-         w, h = img.size
- assert w > 0 and h > 0, "invalid image {} with shape {} {}".format(img_path, w, h)
- if osp.isfile(label_path):
- labels0 = np.loadtxt(label_path, dtype=np.float32).reshape(-1, 6)
-
-             # convert normalized cxcywh boxes to pixel xyxy format
- labels = labels0.copy()
- labels[:, 2] = w * (labels0[:, 2] - labels0[:, 4] / 2)
- labels[:, 3] = h * (labels0[:, 3] - labels0[:, 5] / 2)
- labels[:, 4] = w * (labels0[:, 2] + labels0[:, 4] / 2)
- labels[:, 5] = h * (labels0[:, 3] + labels0[:, 5] / 2)
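-             # e.g. a normalized box (cx, cy, bw, bh) = (0.5, 0.5, 0.2, 0.4) in a
-             # 1920x1080 image becomes (x1, y1, x2, y2) = (768, 324, 1152, 756).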
- else:
- raise ValueError('invalid label path: {}'.format(label_path))
- video_name = '/'.join(label_path.split('/')[:-1])
- obj_idx_offset = self.video_dict[video_name] * 1000000 # 1000000 unique ids is enough for a video.
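-         # e.g. per-video object id 7 in the third registered video (index 2)
-         # becomes global id 2 * 1000000 + 7 = 2000007.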
- if 'crowdhuman' in img_path:
- targets['dataset'] = 'CrowdHuman'
- elif 'MOT17' in img_path:
- targets['dataset'] = 'MOT17'
- else:
- raise NotImplementedError()
- targets['boxes'] = []
- targets['area'] = []
- targets['iscrowd'] = []
- targets['labels'] = []
- targets['obj_ids'] = []
- targets['image_id'] = torch.as_tensor(idx)
- targets['size'] = torch.as_tensor([h, w])
- targets['orig_size'] = torch.as_tensor([h, w])
- for label in labels:
- targets['boxes'].append(label[2:6].tolist())
-             targets['area'].append((label[4] - label[2]) * (label[5] - label[3]))  # pixel box area (w * h)
- targets['iscrowd'].append(0)
- targets['labels'].append(0)
- obj_id = label[1] + obj_idx_offset if label[1] >= 0 else label[1]
-             targets['obj_ids'].append(obj_id)  # id made globally unique by the per-video offset; negative ids kept as-is
-
- targets['area'] = torch.as_tensor(targets['area'])
- targets['iscrowd'] = torch.as_tensor(targets['iscrowd'])
- targets['labels'] = torch.as_tensor(targets['labels'])
- targets['obj_ids'] = torch.as_tensor(targets['obj_ids'])
- targets['boxes'] = torch.as_tensor(targets['boxes'], dtype=torch.float32).reshape(-1, 4)
- # targets['boxes'][:, 0::2].clamp_(min=0, max=w)
- # targets['boxes'][:, 1::2].clamp_(min=0, max=h)
- return img, targets
-
-     def _get_sample_range(self, start_idx):
-         # Default sampling: a clip of self.num_frames_per_batch frames starting at start_idx.
-         assert self.sample_mode in ['fixed_interval', 'random_interval'], 'invalid sample mode: {}'.format(self.sample_mode)
- if self.sample_mode == 'fixed_interval':
- sample_interval = self.sample_interval
- elif self.sample_mode == 'random_interval':
- sample_interval = np.random.randint(1, self.sample_interval + 1)
- default_range = start_idx, start_idx + (self.num_frames_per_batch - 1) * sample_interval + 1, sample_interval
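-         # e.g. start_idx=0, num_frames_per_batch=5 and interval 10 give (0, 41, 10),
-         # i.e. frames 0, 10, 20, 30, 40.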
- return default_range
-
- def pre_continuous_frames(self, start, end, interval=1):
- targets = []
- images = []
- for i in range(start, end, interval):
- img_i, targets_i = self._pre_single_frame(i)
- images.append(img_i)
- targets.append(targets_i)
- return images, targets
-
- def __getitem__(self, idx):
- sample_start, sample_end, sample_interval = self._get_sample_range(idx)
- images, targets = self.pre_continuous_frames(sample_start, sample_end, sample_interval)
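-         # One sample is a short clip: `images` and `targets` hold one entry per sampled frame.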
- data = {}
- dataset_name = targets[0]['dataset']
- transform = self.dataset2transform[dataset_name]
- if transform is not None:
- images, targets = transform(images, targets)
- gt_instances = []
- for img_i, targets_i in zip(images, targets):
- gt_instances_i = self._targets_to_instances(targets_i, img_i.shape[1:3])
- gt_instances.append(gt_instances_i)
- data.update({
- 'imgs': images,
- 'gt_instances': gt_instances,
- })
- if self.args.vis:
- data['ori_img'] = [target_i['ori_img'] for target_i in targets]
- return data
-
- def __len__(self):
- return self.item_num
-
-
- class DetMOTDetectionValidation(DetMOTDetection):
- def __init__(self, args, seqs_folder, dataset2transform):
- args.data_txt_path = args.val_data_txt_path
-         super().__init__(args, data_txt_path=args.data_txt_path, seqs_folder=seqs_folder, dataset2transform=dataset2transform)
-
-
-
- def make_transforms_for_mot17(image_set, args=None):
-
- normalize = T.MotCompose([
- T.MotToTensor(),
- T.MotNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
- ])
- scales = [608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992]
-
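-     # Training: random flip, then either a plain multi-scale resize or a
-     # resize -> random crop -> resize chain (one branch picked at random), then
-     # tensor conversion and normalization. Validation: resize to 800 (max side 1333).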
- if image_set == 'train':
- return T.MotCompose([
- T.MotRandomHorizontalFlip(),
- T.MotRandomSelect(
- T.MotRandomResize(scales, max_size=1536),
- T.MotCompose([
- T.MotRandomResize([400, 500, 600]),
- T.FixedMotRandomCrop(384, 600),
- T.MotRandomResize(scales, max_size=1536),
- ])
- ),
- normalize,
- ])
-
- if image_set == 'val':
- return T.MotCompose([
- T.MotRandomResize([800], max_size=1333),
- normalize,
- ])
-
- raise ValueError(f'unknown {image_set}')
-
-
- def make_transforms_for_crowdhuman(image_set, args=None):
-
- normalize = T.MotCompose([
- T.MotToTensor(),
- T.MotNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
- ])
- scales = [608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992]
-
- if image_set == 'train':
- return T.MotCompose([
- T.MotRandomHorizontalFlip(),
- T.FixedMotRandomShift(bs=1),
- T.MotRandomSelect(
- T.MotRandomResize(scales, max_size=1536),
- T.MotCompose([
- T.MotRandomResize([400, 500, 600]),
- T.FixedMotRandomCrop(384, 600),
- T.MotRandomResize(scales, max_size=1536),
- ])
- ),
- normalize,
-
- ])
-
- if image_set == 'val':
- return T.MotCompose([
- T.MotRandomResize([800], max_size=1333),
- normalize,
- ])
-
- raise ValueError(f'unknown {image_set}')
-
-
- def build_dataset2transform(args, image_set):
- mot17_train = make_transforms_for_mot17('train', args)
- mot17_test = make_transforms_for_mot17('val', args)
-
- crowdhuman_train = make_transforms_for_crowdhuman('train', args)
- dataset2transform_train = {'MOT17': mot17_train, 'CrowdHuman': crowdhuman_train}
- dataset2transform_val = {'MOT17': mot17_test, 'CrowdHuman': mot17_test}
- if image_set == 'train':
- return dataset2transform_train
- elif image_set == 'val':
- return dataset2transform_val
- else:
- raise NotImplementedError()
-
-
- def build(image_set, args):
- root = Path(args.mot_path)
- assert root.exists(), f'provided MOT path {root} does not exist'
- dataset2transform = build_dataset2transform(args, image_set)
-     if image_set == 'train':
-         data_txt_path = args.data_txt_path_train
-     elif image_set == 'val':
-         data_txt_path = args.data_txt_path_val
-     else:
-         raise ValueError(f'unknown image_set: {image_set}')
-     dataset = DetMOTDetection(args, data_txt_path=data_txt_path, seqs_folder=root, dataset2transform=dataset2transform)
-     return dataset
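-
-
- if __name__ == '__main__':
-     # Minimal usage sketch (illustrative only): the Namespace fields below mirror the
-     # attributes this module reads from `args`; the paths and schedule values are
-     # placeholders and must be adapted to the local dataset layout.
-     from argparse import Namespace
-     example_args = Namespace(
-         mot_path='/data/Dataset/mot',                            # assumed dataset root
-         data_txt_path_train='datasets/data_path/mot17.train',    # assumed image-list file
-         data_txt_path_val='datasets/data_path/mot17.train',      # assumed image-list file
-         sampler_lengths=[2, 3, 4, 5],
-         sampler_steps=[50, 90, 120],
-         sample_mode='random_interval',
-         sample_interval=10,
-         vis=False,
-     )
-     dataset = build('train', example_args)
-     print('number of clip samples:', len(dataset))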