Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

opts.py 21KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import argparse
  5. import os
  6. import sys
  7. class opts(object):
  8. def __init__(self):
  9. self.parser = argparse.ArgumentParser()
  10. # basic experiment setting
  11. self.parser.add_argument('task', default='',
  12. help='ctdet | ddd | multi_pose '
  13. '| tracking or combined with ,')
  14. self.parser.add_argument('--dataset', default='coco',
  15. help='see lib/dataset/dataset_facotry for ' +
  16. 'available datasets')
  17. self.parser.add_argument('--test_dataset', default='',
  18. help='coco | kitti | coco_hp | pascal')
  19. self.parser.add_argument('--exp_id', default='default')
  20. self.parser.add_argument('--test', action='store_true')
  21. self.parser.add_argument('--debug', type=int, default=0,
  22. help='level of visualization.'
  23. '1: only show the final detection results'
  24. '2: show the network output features'
  25. '3: use matplot to display' # useful when lunching training with ipython notebook
  26. '4: save all visualizations to disk')
  27. self.parser.add_argument('--no_pause', action='store_true')
  28. self.parser.add_argument('--demo', default='',
  29. help='path to image/ image folders/ video. '
  30. 'or "webcam"')
  31. self.parser.add_argument('--load_model', default='',
  32. help='path to pretrained model')
  33. self.parser.add_argument('--resume', action='store_true',
  34. help='resume an experiment. '
  35. 'Reloaded the optimizer parameter and '
  36. 'set load_model to model_last.pth '
  37. 'in the exp dir if load_model is empty.')
  38. # system
  39. self.parser.add_argument('--gpus', default='0',
  40. help='-1 for CPU, use comma for multiple gpus')
  41. self.parser.add_argument('--num_workers', type=int, default=4,
  42. help='dataloader threads. 0 for single-thread.')
  43. self.parser.add_argument('--not_cuda_benchmark', action='store_true',
  44. help='disable when the input size is not fixed.')
  45. self.parser.add_argument('--seed', type=int, default=317,
  46. help='random seed') # from CornerNet
  47. self.parser.add_argument('--not_set_cuda_env', action='store_true',
  48. help='used when training in slurm clusters.')
  49. # log
  50. self.parser.add_argument('--print_iter', type=int, default=0,
  51. help='disable progress bar and print to screen.')
  52. self.parser.add_argument('--save_all', action='store_true',
  53. help='save model to disk every 5 epochs.')
  54. self.parser.add_argument('--vis_thresh', type=float, default=0.3,
  55. help='visualization threshold.')
  56. self.parser.add_argument('--debugger_theme', default='white',
  57. choices=['white', 'black'])
  58. self.parser.add_argument('--eval_val', action='store_true')
  59. self.parser.add_argument('--save_imgs', default='', help='')
  60. self.parser.add_argument('--save_img_suffix', default='', help='')
  61. self.parser.add_argument('--skip_first', type=int, default=-1, help='')
  62. self.parser.add_argument('--save_video', action='store_true')
  63. self.parser.add_argument('--save_framerate', type=int, default=30)
  64. self.parser.add_argument('--resize_video', action='store_true')
  65. self.parser.add_argument('--video_h', type=int, default=512, help='')
  66. self.parser.add_argument('--video_w', type=int, default=512, help='')
  67. self.parser.add_argument('--transpose_video', action='store_true')
  68. self.parser.add_argument('--show_track_color', action='store_true')
  69. self.parser.add_argument('--not_show_bbox', action='store_true')
  70. self.parser.add_argument('--not_show_number', action='store_true')
  71. self.parser.add_argument('--not_show_txt', action='store_true')
  72. self.parser.add_argument('--qualitative', action='store_true')
  73. self.parser.add_argument('--tango_color', action='store_true')
  74. self.parser.add_argument('--only_show_dots', action='store_true')
  75. self.parser.add_argument('--show_trace', action='store_true')
  76. # model
  77. self.parser.add_argument('--arch', default='dla_34',
  78. help='model architecture. Currently tested'
  79. 'res_18 | res_101 | resdcn_18 | resdcn_101 |'
  80. 'dlav0_34 | dla_34 | hourglass')
  81. self.parser.add_argument('--dla_node', default='dcn')
  82. self.parser.add_argument('--head_conv', type=int, default=-1,
  83. help='conv layer channels for output head'
  84. '0 for no conv layer'
  85. '-1 for default setting: '
  86. '64 for resnets and 256 for dla.')
  87. self.parser.add_argument('--num_head_conv', type=int, default=1)
  88. self.parser.add_argument('--head_kernel', type=int, default=3, help='')
  89. self.parser.add_argument('--down_ratio', type=int, default=4,
  90. help='output stride. Currently only supports 4.')
  91. self.parser.add_argument('--not_idaup', action='store_true')
  92. self.parser.add_argument('--num_classes', type=int, default=-1)
  93. self.parser.add_argument('--num_layers', type=int, default=101)
  94. self.parser.add_argument('--backbone', default='dla34')
  95. self.parser.add_argument('--neck', default='dlaup')
  96. self.parser.add_argument('--msra_outchannel', type=int, default=256)
  97. self.parser.add_argument('--efficient_level', type=int, default=0)
  98. self.parser.add_argument('--prior_bias', type=float, default=-4.6) # -2.19
  99. # input
  100. self.parser.add_argument('--input_res', type=int, default=-1,
  101. help='input height and width. -1 for default from '
  102. 'dataset. Will be overriden by input_h | input_w')
  103. self.parser.add_argument('--input_h', type=int, default=-1,
  104. help='input height. -1 for default from dataset.')
  105. self.parser.add_argument('--input_w', type=int, default=-1,
  106. help='input width. -1 for default from dataset.')
  107. self.parser.add_argument('--dataset_version', default='')
  108. # train
  109. self.parser.add_argument('--optim', default='adam')
  110. self.parser.add_argument('--lr', type=float, default=1.25e-4,
  111. help='learning rate for batch size 32.')
  112. self.parser.add_argument('--lr_step', type=str, default='60',
  113. help='drop learning rate by 10.')
  114. self.parser.add_argument('--save_point', type=str, default='90',
  115. help='when to save the model to disk.')
  116. self.parser.add_argument('--num_epochs', type=int, default=70,
  117. help='total training epochs.')
  118. self.parser.add_argument('--batch_size', type=int, default=32,
  119. help='batch size')
  120. self.parser.add_argument('--master_batch_size', type=int, default=-1,
  121. help='batch size on the master gpu.')
  122. self.parser.add_argument('--num_iters', type=int, default=-1,
  123. help='default: #samples / batch_size.')
  124. self.parser.add_argument('--val_intervals', type=int, default=10000,
  125. help='number of epochs to run validation.')
  126. self.parser.add_argument('--trainval', action='store_true',
  127. help='include validation in training and '
  128. 'test on test set')
  129. self.parser.add_argument('--ltrb', action='store_true',
  130. help='')
  131. self.parser.add_argument('--ltrb_weight', type=float, default=0.1,
  132. help='')
  133. self.parser.add_argument('--reset_hm', action='store_true')
  134. self.parser.add_argument('--reuse_hm', action='store_true')
  135. self.parser.add_argument('--use_kpt_center', action='store_true')
  136. self.parser.add_argument('--add_05', action='store_true')
  137. self.parser.add_argument('--dense_reg', type=int, default=1, help='')
  138. # test
  139. self.parser.add_argument('--flip_test', action='store_true',
  140. help='flip data augmentation.')
  141. self.parser.add_argument('--test_scales', type=str, default='1',
  142. help='multi scale test augmentation.')
  143. self.parser.add_argument('--nms', action='store_true',
  144. help='run nms in testing.')
  145. self.parser.add_argument('--K', type=int, default=100,
  146. help='max number of output objects.')
  147. self.parser.add_argument('--not_prefetch_test', action='store_true',
  148. help='not use parallal data pre-processing.')
  149. self.parser.add_argument('--fix_short', type=int, default=-1)
  150. self.parser.add_argument('--keep_res', action='store_true',
  151. help='keep the original resolution'
  152. ' during validation.')
  153. self.parser.add_argument('--map_argoverse_id', action='store_true',
  154. help='if trained on nuscenes and eval on kitti')
  155. self.parser.add_argument('--out_thresh', type=float, default=-1,
  156. help='')
  157. self.parser.add_argument('--depth_scale', type=float, default=1,
  158. help='')
  159. self.parser.add_argument('--save_results', action='store_true')
  160. self.parser.add_argument('--load_results', default='')
  161. self.parser.add_argument('--use_loaded_results', action='store_true')
  162. self.parser.add_argument('--ignore_loaded_cats', default='')
  163. self.parser.add_argument('--model_output_list', action='store_true',
  164. help='Used when convert to onnx')
  165. self.parser.add_argument('--non_block_test', action='store_true')
  166. self.parser.add_argument('--vis_gt_bev', default='', help='')
  167. self.parser.add_argument('--kitti_split', default='3dop',
  168. help='different validation split for kitti: '
  169. '3dop | subcnn')
  170. self.parser.add_argument('--test_focal_length', type=int, default=-1)
  171. # dataset
  172. self.parser.add_argument('--not_rand_crop', action='store_true',
  173. help='not use the random crop data augmentation'
  174. 'from CornerNet.')
  175. self.parser.add_argument('--not_max_crop', action='store_true',
  176. help='used when the training dataset has'
  177. 'inbalanced aspect ratios.')
  178. self.parser.add_argument('--shift', type=float, default=0,
  179. help='when not using random crop, 0.1'
  180. 'apply shift augmentation.')
  181. self.parser.add_argument('--scale', type=float, default=0,
  182. help='when not using random crop, 0.4'
  183. 'apply scale augmentation.')
  184. self.parser.add_argument('--aug_rot', type=float, default=0,
  185. help='probability of applying '
  186. 'rotation augmentation.')
  187. self.parser.add_argument('--rotate', type=float, default=0,
  188. help='when not using random crop'
  189. 'apply rotation augmentation.')
  190. self.parser.add_argument('--flip', type=float, default=0.5,
  191. help='probability of applying flip augmentation.')
  192. self.parser.add_argument('--no_color_aug', action='store_true',
  193. help='not use the color augmenation '
  194. 'from CornerNet')
  195. # Tracking
  196. self.parser.add_argument('--tracking', action='store_true')
  197. self.parser.add_argument('--pre_hm', action='store_true')
  198. self.parser.add_argument('--same_aug_pre', action='store_true')
  199. self.parser.add_argument('--zero_pre_hm', action='store_true')
  200. self.parser.add_argument('--hm_disturb', type=float, default=0)
  201. self.parser.add_argument('--lost_disturb', type=float, default=0)
  202. self.parser.add_argument('--fp_disturb', type=float, default=0)
  203. self.parser.add_argument('--pre_thresh', type=float, default=-1)
  204. self.parser.add_argument('--track_thresh', type=float, default=0.3)
  205. self.parser.add_argument('--match_thresh', type=float, default=0.8)
  206. self.parser.add_argument('--track_buffer', type=int, default=30)
  207. self.parser.add_argument('--new_thresh', type=float, default=0.3)
  208. self.parser.add_argument('--max_frame_dist', type=int, default=3)
  209. self.parser.add_argument('--ltrb_amodal', action='store_true')
  210. self.parser.add_argument('--ltrb_amodal_weight', type=float, default=0.1)
  211. self.parser.add_argument('--public_det', action='store_true')
  212. self.parser.add_argument('--no_pre_img', action='store_true')
  213. self.parser.add_argument('--zero_tracking', action='store_true')
  214. self.parser.add_argument('--hungarian', action='store_true')
  215. self.parser.add_argument('--max_age', type=int, default=-1)
  216. # loss
  217. self.parser.add_argument('--tracking_weight', type=float, default=1)
  218. self.parser.add_argument('--reg_loss', default='l1',
  219. help='regression loss: sl1 | l1 | l2')
  220. self.parser.add_argument('--hm_weight', type=float, default=1,
  221. help='loss weight for keypoint heatmaps.')
  222. self.parser.add_argument('--off_weight', type=float, default=1,
  223. help='loss weight for keypoint local offsets.')
  224. self.parser.add_argument('--wh_weight', type=float, default=0.1,
  225. help='loss weight for bounding box size.')
  226. self.parser.add_argument('--hp_weight', type=float, default=1,
  227. help='loss weight for human pose offset.')
  228. self.parser.add_argument('--hm_hp_weight', type=float, default=1,
  229. help='loss weight for human keypoint heatmap.')
  230. self.parser.add_argument('--amodel_offset_weight', type=float, default=1,
  231. help='Please forgive the typo.')
  232. self.parser.add_argument('--dep_weight', type=float, default=1,
  233. help='loss weight for depth.')
  234. self.parser.add_argument('--dim_weight', type=float, default=1,
  235. help='loss weight for 3d bounding box size.')
  236. self.parser.add_argument('--rot_weight', type=float, default=1,
  237. help='loss weight for orientation.')
  238. self.parser.add_argument('--nuscenes_att', action='store_true')
  239. self.parser.add_argument('--nuscenes_att_weight', type=float, default=1)
  240. self.parser.add_argument('--velocity', action='store_true')
  241. self.parser.add_argument('--velocity_weight', type=float, default=1)
  242. # custom dataset
  243. self.parser.add_argument('--custom_dataset_img_path', default='')
  244. self.parser.add_argument('--custom_dataset_ann_path', default='')
  245. self.parser.add_argument('--bird_view_world_size', type=int, default=64)
  246. def parse(self, args=''):
  247. if args == '':
  248. opt = self.parser.parse_args()
  249. else:
  250. opt = self.parser.parse_args(args)
  251. if opt.test_dataset == '':
  252. opt.test_dataset = opt.dataset
  253. opt.gpus_str = opt.gpus
  254. opt.gpus = [int(gpu) for gpu in opt.gpus.split(',')]
  255. opt.gpus = [i for i in range(len(opt.gpus))] if opt.gpus[0] >=0 else [-1]
  256. opt.lr_step = [int(i) for i in opt.lr_step.split(',')]
  257. opt.save_point = [int(i) for i in opt.save_point.split(',')]
  258. opt.test_scales = [float(i) for i in opt.test_scales.split(',')]
  259. opt.save_imgs = [i for i in opt.save_imgs.split(',')] \
  260. if opt.save_imgs != '' else []
  261. opt.ignore_loaded_cats = \
  262. [int(i) for i in opt.ignore_loaded_cats.split(',')] \
  263. if opt.ignore_loaded_cats != '' else []
  264. opt.num_workers = max(opt.num_workers, 2 * len(opt.gpus))
  265. opt.pre_img = False
  266. if 'tracking' in opt.task:
  267. print('Running tracking')
  268. opt.tracking = True
  269. # opt.out_thresh = max(opt.track_thresh, opt.out_thresh)
  270. # opt.pre_thresh = max(opt.track_thresh, opt.pre_thresh)
  271. # opt.new_thresh = max(opt.track_thresh, opt.new_thresh)
  272. opt.pre_img = not opt.no_pre_img
  273. print('Using tracking threshold for out threshold!', opt.track_thresh)
  274. if 'ddd' in opt.task:
  275. opt.show_track_color = True
  276. opt.fix_res = not opt.keep_res
  277. print('Fix size testing.' if opt.fix_res else 'Keep resolution testing.')
  278. if opt.head_conv == -1: # init default head_conv
  279. opt.head_conv = 256 if 'dla' in opt.arch else 64
  280. opt.pad = 127 if 'hourglass' in opt.arch else 31
  281. opt.num_stacks = 2 if opt.arch == 'hourglass' else 1
  282. if opt.master_batch_size == -1:
  283. opt.master_batch_size = opt.batch_size // len(opt.gpus)
  284. rest_batch_size = (opt.batch_size - opt.master_batch_size)
  285. opt.chunk_sizes = [opt.master_batch_size]
  286. for i in range(len(opt.gpus) - 1):
  287. slave_chunk_size = rest_batch_size // (len(opt.gpus) - 1)
  288. if i < rest_batch_size % (len(opt.gpus) - 1):
  289. slave_chunk_size += 1
  290. opt.chunk_sizes.append(slave_chunk_size)
  291. print('training chunk_sizes:', opt.chunk_sizes)
  292. if opt.debug > 0:
  293. opt.num_workers = 0
  294. opt.batch_size = 1
  295. opt.gpus = [opt.gpus[0]]
  296. opt.master_batch_size = -1
  297. # log dirs
  298. opt.root_dir = os.path.join(os.path.dirname(__file__), '..', '..')
  299. opt.data_dir = os.path.join(opt.root_dir, 'data')
  300. opt.exp_dir = os.path.join(opt.root_dir, 'exp', opt.task)
  301. opt.save_dir = os.path.join(opt.exp_dir, opt.exp_id)
  302. opt.debug_dir = os.path.join(opt.save_dir, 'debug')
  303. if opt.resume and opt.load_model == '':
  304. opt.load_model = os.path.join(opt.save_dir, 'model_last.pth')
  305. return opt
  306. def update_dataset_info_and_set_heads(self, opt, dataset):
  307. opt.num_classes = dataset.num_categories \
  308. if opt.num_classes < 0 else opt.num_classes
  309. # input_h(w): opt.input_h overrides opt.input_res overrides dataset default
  310. input_h, input_w = dataset.default_resolution
  311. input_h = opt.input_res if opt.input_res > 0 else input_h
  312. input_w = opt.input_res if opt.input_res > 0 else input_w
  313. opt.input_h = opt.input_h if opt.input_h > 0 else input_h
  314. opt.input_w = opt.input_w if opt.input_w > 0 else input_w
  315. opt.output_h = opt.input_h // opt.down_ratio
  316. opt.output_w = opt.input_w // opt.down_ratio
  317. opt.input_res = max(opt.input_h, opt.input_w)
  318. opt.output_res = max(opt.output_h, opt.output_w)
  319. opt.heads = {'hm': opt.num_classes, 'reg': 2, 'wh': 2}
  320. if 'tracking' in opt.task:
  321. opt.heads.update({'tracking': 2})
  322. if 'ddd' in opt.task:
  323. opt.heads.update({'dep': 1, 'rot': 8, 'dim': 3, 'amodel_offset': 2})
  324. if 'multi_pose' in opt.task:
  325. opt.heads.update({
  326. 'hps': dataset.num_joints * 2, 'hm_hp': dataset.num_joints,
  327. 'hp_offset': 2})
  328. if opt.ltrb:
  329. opt.heads.update({'ltrb': 4})
  330. if opt.ltrb_amodal:
  331. opt.heads.update({'ltrb_amodal': 4})
  332. if opt.nuscenes_att:
  333. opt.heads.update({'nuscenes_att': 8})
  334. if opt.velocity:
  335. opt.heads.update({'velocity': 3})
  336. weight_dict = {'hm': opt.hm_weight, 'wh': opt.wh_weight,
  337. 'reg': opt.off_weight, 'hps': opt.hp_weight,
  338. 'hm_hp': opt.hm_hp_weight, 'hp_offset': opt.off_weight,
  339. 'dep': opt.dep_weight, 'rot': opt.rot_weight,
  340. 'dim': opt.dim_weight,
  341. 'amodel_offset': opt.amodel_offset_weight,
  342. 'ltrb': opt.ltrb_weight,
  343. 'tracking': opt.tracking_weight,
  344. 'ltrb_amodal': opt.ltrb_amodal_weight,
  345. 'nuscenes_att': opt.nuscenes_att_weight,
  346. 'velocity': opt.velocity_weight}
  347. opt.weights = {head: weight_dict[head] for head in opt.heads}
  348. for head in opt.weights:
  349. if opt.weights[head] == 0:
  350. del opt.heads[head]
  351. opt.head_conv = {head: [opt.head_conv \
  352. for i in range(opt.num_head_conv if head != 'reg' else 1)] for head in opt.heads}
  353. print('input h w:', opt.input_h, opt.input_w)
  354. print('heads', opt.heads)
  355. print('weights', opt.weights)
  356. print('head conv', opt.head_conv)
  357. return opt
  358. def init(self, args=''):
  359. # only used in demo
  360. default_dataset_info = {
  361. 'ctdet': 'coco', 'multi_pose': 'coco_hp', 'ddd': 'nuscenes',
  362. 'tracking,ctdet': 'coco', 'tracking,multi_pose': 'coco_hp',
  363. 'tracking,ddd': 'nuscenes'
  364. }
  365. opt = self.parse()
  366. from dataset.dataset_factory import dataset_factory
  367. train_dataset = default_dataset_info[opt.task] \
  368. if opt.task in default_dataset_info else 'coco'
  369. dataset = dataset_factory[train_dataset]
  370. opt = self.update_dataset_info_and_set_heads(opt, dataset)
  371. return opt