Meta Byte Track

trainer.py (11 KB)

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

from loguru import logger

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from yolox.data import DataPrefetcher
from yolox.utils import (
    MeterBuffer,
    ModelEMA,
    all_reduce_norm,
    get_model_info,
    get_rank,
    get_world_size,
    gpu_mem_usage,
    load_ckpt,
    occupy_mem,
    save_checkpoint,
    setup_logger,
    synchronize,
)

import datetime
import os
import time


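# Trainer drives the full YOLOX-style training loop used by this repository:
# before_train() builds the model, optimizer, dataloader and evaluator, then
# train() runs nested epoch/iteration loops with hooks (before_epoch, after_iter,
# ...) for logging, augmentation switching, evaluation and checkpointing.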
class Trainer:
    def __init__(self, exp, args):
        # __init__ only defines basic attributes; other attributes such as the
        # model, optimizer and dataloader are built in the before_train method.
        self.exp = exp
        self.args = args

        # training related attr
        self.max_epoch = exp.max_epoch
        self.amp_training = args.fp16
        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
        self.is_distributed = get_world_size() > 1
        self.rank = get_rank()
        self.local_rank = args.local_rank
        self.device = "cuda:{}".format(self.local_rank)
        self.use_model_ema = exp.ema

        # data/dataloader related attr
        self.data_type = torch.float16 if args.fp16 else torch.float32
        self.input_size = exp.input_size
        self.best_ap = 0

        # metric record
        self.meter = MeterBuffer(window_size=exp.print_interval)
        self.file_name = os.path.join(exp.output_dir, args.experiment_name)

        if self.rank == 0:
            os.makedirs(self.file_name, exist_ok=True)

        setup_logger(
            self.file_name,
            distributed_rank=self.rank,
            filename="train_log.txt",
            mode="a",
        )

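    # train() is the public entry point: setup, the epoch loop, and the teardown
    # hook in a try/finally so after_train() runs even if training raises.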
    def train(self):
        self.before_train()
        try:
            self.train_in_epoch()
        except Exception:
            raise
        finally:
            self.after_train()

    def train_in_epoch(self):
        for self.epoch in range(self.start_epoch, self.max_epoch):
            self.before_epoch()
            self.train_in_iter()
            self.after_epoch()

    def train_in_iter(self):
        for self.iter in range(self.max_iter):
            self.before_iter()
            self.train_one_iter()
            self.after_iter()

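    # One optimization step: take a prefetched batch, run the forward pass under
    # autocast (mixed precision when fp16 is enabled), backprop the scaled loss
    # through the GradScaler, update the EMA weights, and set the learning rate
    # for the next iteration from the scheduler.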
    def train_one_iter(self):
        iter_start_time = time.time()

        inps, targets = self.prefetcher.next()
        inps = inps.to(self.data_type)
        targets = targets.to(self.data_type)
        targets.requires_grad = False
        data_end_time = time.time()

        with torch.cuda.amp.autocast(enabled=self.amp_training):
            outputs = self.model(inps, targets)
            loss = outputs["total_loss"]

        self.optimizer.zero_grad()
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        if self.use_model_ema:
            self.ema_model.update(self.model)

        lr = self.lr_scheduler.update_lr(self.progress_in_iter + 1)
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

        iter_end_time = time.time()
        self.meter.update(
            iter_time=iter_end_time - iter_start_time,
            data_time=data_end_time - iter_start_time,
            lr=lr,
            **outputs,
        )

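    # One-time setup before the first epoch: build model and optimizer, resume
    # or fine-tune from a checkpoint, create the (possibly distributed) dataloader
    # and prefetcher, wrap the model in DDP, and initialize EMA, the evaluator and
    # the TensorBoard writer.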
    def before_train(self):
        logger.info("args: {}".format(self.args))
        logger.info("exp value:\n{}".format(self.exp))

        # model related init
        torch.cuda.set_device(self.local_rank)
        model = self.exp.get_model()
        logger.info(
            "Model Summary: {}".format(get_model_info(model, self.exp.test_size))
        )
        model.to(self.device)

        # solver related init
        self.optimizer = self.exp.get_optimizer(self.args.batch_size)

        # value of epoch will be set in `resume_train`
        model = self.resume_train(model)

        # data related init
        self.no_aug = self.start_epoch >= self.max_epoch - self.exp.no_aug_epochs
        self.train_loader = self.exp.get_data_loader(
            batch_size=self.args.batch_size,
            is_distributed=self.is_distributed,
            no_aug=self.no_aug,
        )
        logger.info("init prefetcher, this might take one minute or less...")
        self.prefetcher = DataPrefetcher(self.train_loader)

        # max_iter means iters per epoch
        self.max_iter = len(self.train_loader)

        self.lr_scheduler = self.exp.get_lr_scheduler(
            self.exp.basic_lr_per_img * self.args.batch_size, self.max_iter
        )
        if self.args.occupy:
            occupy_mem(self.local_rank)

        if self.is_distributed:
            model = DDP(model, device_ids=[self.local_rank], broadcast_buffers=False)

        if self.use_model_ema:
            self.ema_model = ModelEMA(model, 0.9998)
            self.ema_model.updates = self.max_iter * self.start_epoch

        self.model = model
        self.model.train()

        self.evaluator = self.exp.get_evaluator(
            batch_size=self.args.batch_size, is_distributed=self.is_distributed
        )

        # Tensorboard logger
        if self.rank == 0:
            self.tblogger = SummaryWriter(self.file_name)

        logger.info("Training start...")
        # logger.info("\n{}".format(model))

    def after_train(self):
        logger.info(
            "Training of experiment is done and the best AP is {:.2f}".format(
                self.best_ap * 100
            )
        )

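    # For the final no_aug_epochs epochs, mosaic augmentation is turned off, the
    # extra L1 box loss is enabled on the head, and evaluation runs every epoch.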
    def before_epoch(self):
        logger.info("---> start train epoch{}".format(self.epoch + 1))

        if self.epoch + 1 == self.max_epoch - self.exp.no_aug_epochs or self.no_aug:
            logger.info("--->No mosaic aug now!")
            self.train_loader.close_mosaic()
            logger.info("--->Add additional L1 loss now!")
            if self.is_distributed:
                self.model.module.head.use_l1 = True
            else:
                self.model.head.use_l1 = True
            self.exp.eval_interval = 1
            if not self.no_aug:
                self.save_ckpt(ckpt_name="last_mosaic_epoch")

    def after_epoch(self):
        if self.use_model_ema:
            self.ema_model.update_attr(self.model)

        self.save_ckpt(ckpt_name="latest")

        if (self.epoch + 1) % self.exp.eval_interval == 0:
            all_reduce_norm(self.model)
            self.evaluate_and_save_model()

    def before_iter(self):
        pass

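    # Logging happens every print_interval iterations; the multi-scale input size
    # is re-drawn every 10 iterations when the experiment sets random_size.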
    def after_iter(self):
        """
        `after_iter` contains two parts of logic:
            * log information
            * reset setting of resize
        """
        # log needed information
        if (self.iter + 1) % self.exp.print_interval == 0:
            # TODO check ETA logic
            left_iters = self.max_iter * self.max_epoch - (self.progress_in_iter + 1)
            eta_seconds = self.meter["iter_time"].global_avg * left_iters
            eta_str = "ETA: {}".format(datetime.timedelta(seconds=int(eta_seconds)))

            progress_str = "epoch: {}/{}, iter: {}/{}".format(
                self.epoch + 1, self.max_epoch, self.iter + 1, self.max_iter
            )
            loss_meter = self.meter.get_filtered_meter("loss")
            loss_str = ", ".join(
                ["{}: {:.3f}".format(k, v.latest) for k, v in loss_meter.items()]
            )

            time_meter = self.meter.get_filtered_meter("time")
            time_str = ", ".join(
                ["{}: {:.3f}s".format(k, v.avg) for k, v in time_meter.items()]
            )

            logger.info(
                "{}, mem: {:.0f}Mb, {}, {}, lr: {:.3e}".format(
                    progress_str,
                    gpu_mem_usage(),
                    time_str,
                    loss_str,
                    self.meter["lr"].latest,
                )
                + (", size: {:d}, {}".format(self.input_size[0], eta_str))
            )
            self.meter.clear_meters()

        # random resizing
        if self.exp.random_size is not None and (self.progress_in_iter + 1) % 10 == 0:
            self.input_size = self.exp.random_resize(
                self.train_loader, self.epoch, self.rank, self.is_distributed
            )

    @property
    def progress_in_iter(self):
        return self.epoch * self.max_iter + self.iter

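    # Checkpoint handling: with --resume, model/optimizer state and the starting
    # epoch are restored from the latest (or a given) checkpoint; otherwise a
    # provided --ckpt only initializes weights for fine-tuning and training
    # starts from epoch 0.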
    def resume_train(self, model):
        if self.args.resume:
            logger.info("resume training")
            if self.args.ckpt is None:
                ckpt_file = os.path.join(self.file_name, "latest" + "_ckpt.pth.tar")
            else:
                ckpt_file = self.args.ckpt

            ckpt = torch.load(ckpt_file, map_location=self.device)
            # resume the model/optimizer state dict
            model.load_state_dict(ckpt["model"])
            self.optimizer.load_state_dict(ckpt["optimizer"])
            start_epoch = (
                self.args.start_epoch - 1
                if self.args.start_epoch is not None
                else ckpt["start_epoch"]
            )
            self.start_epoch = start_epoch
            logger.info(
                "loaded checkpoint '{}' (epoch {})".format(
                    self.args.resume, self.start_epoch
                )
            )  # noqa
        else:
            if self.args.ckpt is not None:
                logger.info("loading checkpoint for fine tuning")
                ckpt_file = self.args.ckpt
                ckpt = torch.load(ckpt_file, map_location=self.device)["model"]
                model = load_ckpt(model, ckpt)
            self.start_epoch = 0

        return model

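    # Evaluation uses the EMA weights when EMA is enabled. The "best" checkpoint
    # is tracked by COCO AP50 here; an AP50:95 variant is left commented out below.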
    def evaluate_and_save_model(self):
        evalmodel = self.ema_model.ema if self.use_model_ema else self.model
        ap50_95, ap50, summary = self.exp.eval(
            evalmodel, self.evaluator, self.is_distributed
        )
        self.model.train()

        if self.rank == 0:
            self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
            self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
            logger.info("\n" + summary)
        synchronize()

        # self.best_ap = max(self.best_ap, ap50_95)
        self.save_ckpt("last_epoch", ap50 > self.best_ap)
        self.best_ap = max(self.best_ap, ap50)

    def save_ckpt(self, ckpt_name, update_best_ckpt=False):
        if self.rank == 0:
            save_model = self.ema_model.ema if self.use_model_ema else self.model
            logger.info("Save weights to {}".format(self.file_name))
            ckpt_state = {
                "start_epoch": self.epoch + 1,
                "model": save_model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
            }
            save_checkpoint(
                ckpt_state,
                update_best_ckpt,
                self.file_name,
                ckpt_name,
            )
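
For reference, the Trainer above is normally driven by a small launcher along these lines. This is a minimal sketch, not the repository's actual training script (which is not shown on this page): `get_exp` follows the YOLOX convention, and `launch_training` and its parameters are hypothetical names.

from yolox.exp import get_exp

def launch_training(exp_file, args):
    # Hypothetical helper: args must carry the fields Trainer reads, e.g.
    # experiment_name, batch_size, fp16, local_rank, occupy, resume, ckpt,
    # start_epoch.
    exp = get_exp(exp_file, None)   # load the experiment description
    trainer = Trainer(exp, args)    # the class defined in trainer.py above
    trainer.train()                 # run the full training loop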