Meta Byte Track

meta_trainer.py 13KB
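A MAML-based meta-learning variant of the YOLOX/ByteTrack trainer: the detector is wrapped with learn2learn's MAML, cloned once per task (one data loader per task), adapted in an inner loop, and the adapted clone's loss is backpropagated for the outer optimizer step.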

# Mahdi Abdollahpour, 27/11/2021, 02:35 PM, PyCharm, ByteTrack

from loguru import logger

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from yolox.data import DataPrefetcher
from yolox.utils import (
    MeterBuffer,
    ModelEMA,
    all_reduce_norm,
    get_model_info,
    get_rank,
    get_world_size,
    gpu_mem_usage,
    load_ckpt,
    occupy_mem,
    save_checkpoint,
    setup_logger,
    synchronize,
)

import datetime
import os
import time

import learn2learn as l2l
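
# Training flow, outer to inner:
#   train() -> train_in_epoch() -> train_in_task() -> train_in_iter() -> train_one_iter()
# Each epoch iterates over several task-specific data loaders; before_task() clones the
# MAML-wrapped model, and the inner iterations of that task train against the clone.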

class MetaTrainer:
    def __init__(self, exp, args):
        # init function only defines some basic attr, other attrs like model, optimizer are built in
        # before_train methods.
        self.exp = exp
        self.args = args

        # training related attr
        self.max_epoch = exp.max_epoch
        self.amp_training = args.fp16
        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
        self.is_distributed = get_world_size() > 1
        self.rank = get_rank()
        self.local_rank = args.local_rank
        self.device = "cuda:{}".format(self.local_rank)
        self.use_model_ema = exp.ema

        # data/dataloader related attr
        self.data_type = torch.float16 if args.fp16 else torch.float32
        self.input_size = exp.input_size
        self.best_ap = 0
        self.adaptation_period = args.adaptation_period

        # metric record
        self.meter = MeterBuffer(window_size=exp.print_interval)
        self.file_name = os.path.join(exp.output_dir, args.experiment_name)

        if self.rank == 0:
            os.makedirs(self.file_name, exist_ok=True)

        setup_logger(
            self.file_name,
            distributed_rank=self.rank,
            filename="train_log.txt",
            mode="a",
        )

    def train(self):
        self.before_train()
        try:
            self.train_in_epoch()
        except Exception:
            raise
        finally:
            self.after_train()

    def train_in_epoch(self):
        # self.evaluate_and_save_model()
        for self.epoch in range(self.start_epoch, self.max_epoch):
            self.before_epoch()
            self.train_in_task()
            self.after_epoch()

    def train_in_iter(self):
        for self.iter in range(len(self.train_loader)):
            self.before_iter()
            self.train_one_iter()
            self.after_iter()

    def train_in_task(self):
        for i in range(len(self.train_loaders)):
            self.before_task(i)
            self.train_in_iter()
            self.after_task(i)

    def before_task(self, i):
        # logger.info("init prefetcher, this might take one minute or less...")
        self.train_loader = self.train_loaders[i]
        self.prefetcher = self.prefetchers[i]
        self.learner = self.model.clone()
        logger.info("model clone created. dataloader:{}".format(i))

    def after_task(self, i):
        self.epoch_progress += len(self.train_loaders[i])
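
    # `adapt` is the MAML inner-loop update: it runs the batch through the cloned
    # module, takes the detector's total loss, and calls learn2learn's MAML.adapt(),
    # which applies an in-place gradient step to the clone's parameters using the
    # inner learning rate (exp.inner_lr) configured in before_train.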
    def adapt(self, inps, targets):
        # adapt_inps = inps[:1, ...]
        # targets_inps = targets[:1, ...]
        # print(adapt_inps.shape)
        # print(targets_inps.shape)
        outputs = self.learner(inps, targets)
        loss = outputs["total_loss"]
        self.learner.adapt(loss)
        del outputs, loss
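
    # Outer-loop step: the clone computes the batch loss on every iteration (and is
    # re-adapted every `adaptation_period` iterations); the loss is scaled by the AMP
    # GradScaler, backpropagated, and stepped with the optimizer, after which the LR
    # scheduler and (optionally) the EMA weights are updated.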
    def train_one_iter(self):
        iter_start_time = time.time()

        inps, targets = self.prefetcher.next()
        inps = inps.to(self.data_type)
        targets = targets.to(self.data_type)
        targets.requires_grad = False
        data_end_time = time.time()

        # logger.info("in train iter")
        with torch.cuda.amp.autocast(enabled=self.amp_training):
            if self.iter % self.adaptation_period == 0:
                self.adapt(inps, targets)
            outputs = self.learner(inps, targets)

        loss = outputs["total_loss"]

        for p in self.exp.all_parameters:
            if p.grad is not None:
                p.grad.data.mul_(1.0 / self.args.batch_size)

        self.optimizer.zero_grad()
        # logger.info("loss Norm: {} , scale {}".format(torch.norm(loss), self.scaler.get_scale()))
        # loss = self.scaler.scale(loss)
        # logger.info("loss Norm: {} , scale {}".format(torch.norm(loss), self.scaler.get_scale()))
        self.scaler.scale(loss).backward()
        # loss.backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        if self.use_model_ema:
            self.ema_model.update(self.model)

        lr = self.lr_scheduler.update_lr(self.progress_in_iter + 1)
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

        iter_end_time = time.time()
        self.meter.update(
            iter_time=iter_end_time - iter_start_time,
            data_time=data_end_time - iter_start_time,
            lr=lr,
            **outputs,
        )

        del loss, outputs

    def before_train(self):
        logger.info("args: {}".format(self.args))
        # logger.info("exp value:\n{}".format(self.exp))

        # model related init
        torch.cuda.set_device(self.local_rank)
        model = self.exp.get_model()
        logger.info(
            "Model Summary: {}".format(get_model_info(model, self.exp.test_size))
        )
        # exit()
        model.to(self.device)
        # from torchsummary import summary
        # summary(model, input_size=(3, 300, 300), device='cuda')

        # value of epoch will be set in `resume_train`
        # Wrap the detector with learn2learn's MAML so it can be cloned per task and
        # adapted with exp.inner_lr; exp.first_order toggles first-order MAML.
        self.model = l2l.algorithms.MAML(model, lr=self.exp.inner_lr, first_order=self.exp.first_order)

        # solver related init
        self.optimizer = self.exp.get_optimizer(self.args.batch_size)
        self.model = self.resume_train(self.model)

        # data related init
        self.no_aug = self.start_epoch >= self.max_epoch - self.exp.no_aug_epochs
        logger.info('Getting data loaders')
        self.train_loaders = self.exp.get_data_loaders(
            batch_size=self.args.batch_size,
            is_distributed=self.is_distributed,
            no_aug=self.no_aug,
        )

        # max_iter means iters per epoch (summed over all task loaders)
        self.max_iter = 0
        self.prefetchers = []
        for i, train_loader in enumerate(self.train_loaders):
            self.max_iter += len(train_loader)
            self.prefetchers.append(DataPrefetcher(train_loader))

        self.lr_scheduler = self.exp.get_lr_scheduler(
            self.exp.basic_lr_per_img * self.args.batch_size, self.max_iter
        )
        if self.args.occupy:
            occupy_mem(self.local_rank)

        if self.is_distributed:
            self.model = DDP(self.model, device_ids=[self.local_rank], broadcast_buffers=False)

        if self.use_model_ema:
            self.ema_model = ModelEMA(self.model, 0.9998)
            self.ema_model.updates = self.max_iter * self.start_epoch

        # self.model = model
        self.model.train()

        self.evaluators = self.exp.get_evaluators(
            batch_size=self.args.batch_size, is_distributed=self.is_distributed
        )
        # Tensorboard logger
        if self.rank == 0:
            self.tblogger = SummaryWriter(self.file_name)

        logger.info("Training start...")
        # logger.info("\n{}".format(model))

    def after_train(self):
        logger.info(
            "Training of experiment is done and the best AP is {:.2f}".format(
                self.best_ap * 100
            )
        )

    def before_epoch(self):
        logger.info("---> start train epoch{}".format(self.epoch + 1))
        self.epoch_progress = 0

        if self.epoch + 1 == self.max_epoch - self.exp.no_aug_epochs or self.no_aug:
            logger.info("--->No mosaic aug now!")
            for train_loader in self.train_loaders:
                train_loader.close_mosaic()
            logger.info("--->Add additional L1 loss now!")
            if self.is_distributed:
                self.model.module.head.use_l1 = True
            else:
                self.model.head.use_l1 = True
            self.exp.eval_interval = 1
            if not self.no_aug:
                self.save_ckpt(ckpt_name="last_mosaic_epoch")

    def after_epoch(self):
        self.epoch_progress = 0
        if self.use_model_ema:
            self.ema_model.update_attr(self.model)

        self.save_ckpt(ckpt_name="latest")

        if (self.epoch + 1) % self.exp.eval_interval == 0:
            all_reduce_norm(self.model)
            self.evaluate_and_save_model()

    def before_iter(self):
        pass

    def after_iter(self):
        """
        `after_iter` contains two parts of logic:
            * log information
            * reset setting of resize
        """
        # log needed information
        if (self.iter + 1) % self.exp.print_interval == 0:
            # TODO check ETA logic
            left_iters = self.max_iter * self.max_epoch - (self.progress_in_iter + 1)
            eta_seconds = self.meter["iter_time"].global_avg * left_iters
            eta_str = "ETA: {}".format(datetime.timedelta(seconds=int(eta_seconds)))

            progress_str = "epoch: {}/{}, iter: {}/{}".format(
                self.epoch + 1, self.max_epoch, self.iter + 1 + self.epoch_progress, self.max_iter
            )
            loss_meter = self.meter.get_filtered_meter("loss")
            loss_str = ", ".join(
                ["{}: {:.3f}".format(k, v.latest) for k, v in loss_meter.items()]
            )

            time_meter = self.meter.get_filtered_meter("time")
            time_str = ", ".join(
                ["{}: {:.3f}s".format(k, v.avg) for k, v in time_meter.items()]
            )

            logger.info(
                "{}, mem: {:.0f}Mb, {}, {}, lr: {:.3e}".format(
                    progress_str,
                    gpu_mem_usage(),
                    time_str,
                    loss_str,
                    self.meter["lr"].latest,
                )
                + (", size: {:d}, {}".format(self.input_size[0], eta_str))
            )
            self.meter.clear_meters()

        # random resizing
        if self.exp.random_size is not None and (self.progress_in_iter + 1) % 10 == 0:
            self.input_size = self.exp.random_resize(
                self.train_loader, self.epoch, self.rank, self.is_distributed
            )

    @property
    def progress_in_iter(self):
        # self.iter counts within the current task loader; epoch_progress accumulates the
        # lengths of task loaders already finished this epoch, so progress stays monotonic
        # across tasks within an epoch.
        return self.epoch * self.max_iter + self.iter + self.epoch_progress

    def resume_train(self, model):
        if self.args.resume:
            logger.info("resume training")
            if self.args.ckpt is None:
                ckpt_file = os.path.join(self.file_name, "latest" + "_ckpt.pth.tar")
            else:
                ckpt_file = self.args.ckpt

            ckpt = torch.load(ckpt_file, map_location=self.device)
            # TODO: handle pretrained BYTETrack
            # handling meta models
            # new_dict = {}
            # for key in ckpt["model"].keys():
            #     if key.startswith('module.'):
            #         new_dict[key[7:]] = ckpt["model"][key]
            #     else:
            #         new_dict[key] = ckpt["model"][key]
            # del ckpt["model"]
            # ckpt["model"] = new_dict

            # resume the model/optimizer state dict
            model.load_state_dict(ckpt["model"])
            self.optimizer.load_state_dict(ckpt["optimizer"])
            start_epoch = (
                self.args.start_epoch - 1
                if self.args.start_epoch is not None
                else ckpt["start_epoch"]
            )
            self.start_epoch = start_epoch
            logger.info(
                "loaded checkpoint '{}' (epoch {})".format(
                    self.args.resume, self.start_epoch
                )
            )  # noqa
        else:
            if self.args.ckpt is not None:
                logger.info("loading checkpoint for fine tuning")
                ckpt_file = self.args.ckpt
                ckpt = torch.load(ckpt_file, map_location=self.device)["model"]
                model = load_ckpt(model, ckpt)
            self.start_epoch = 0

        return model

    def evaluate_and_save_model(self):
        logger.info("starting eval...")
        evalmodel = self.ema_model.ema if self.use_model_ema else self.model
        ap50_95, ap50, summary = self.exp.eval(
            evalmodel, self.evaluators, self.is_distributed
        )
        self.model.train()
        if self.rank == 0:
            self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
            self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
            logger.info("\n" + summary)
        synchronize()

        # self.best_ap = max(self.best_ap, ap50_95)
        self.save_ckpt("last_epoch", ap50 > self.best_ap)
        self.best_ap = max(self.best_ap, ap50)

    def save_ckpt(self, ckpt_name, update_best_ckpt=False):
        if self.rank == 0:
            save_model = self.ema_model.ema if self.use_model_ema else self.model
            logger.info("Save weights to {}".format(self.file_name))
            ckpt_state = {
                "start_epoch": self.epoch + 1,
                "model": save_model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
            }
            save_checkpoint(
                ckpt_state,
                update_best_ckpt,
                self.file_name,
                ckpt_name,
            )
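
For orientation, a minimal single-GPU launch sketch. `MetaExp` is a hypothetical stand-in for a YOLOX-style Exp subclass that provides the hooks this trainer calls (get_model, get_data_loaders, get_evaluators, get_optimizer, get_lr_scheduler, eval, random_resize, plus inner_lr, first_order, all_parameters); the module path and flag values below are illustrative assumptions, not defaults from this repository.

# Hypothetical usage -- Exp subclass, module path, and values are assumptions.
from types import SimpleNamespace

from meta_trainer import MetaTrainer
from my_meta_exp import MetaExp  # hypothetical Exp subclass

args = SimpleNamespace(
    experiment_name="meta_bytetrack",  # output folder under exp.output_dir
    batch_size=8,
    fp16=True,                 # AMP autocast + GradScaler
    local_rank=0,
    adaptation_period=4,       # run the inner-loop adapt() every 4 iterations
    occupy=False,              # pre-occupy GPU memory if True
    resume=False,
    ckpt=None,
    start_epoch=None,
)

trainer = MetaTrainer(MetaExp(), args)
trainer.train()  # before_train -> epoch/task/iter loops -> after_train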