# Mahdi Abdollahpour, 27/11/2021, 02:35 PM, PyCharm, ByteTrack

from loguru import logger

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from yolox.data import DataPrefetcher
from yolox.utils import (
    MeterBuffer,
    ModelEMA,
    all_reduce_norm,
    get_model_info,
    get_rank,
    get_world_size,
    gpu_mem_usage,
    load_ckpt,
    occupy_mem,
    save_checkpoint,
    setup_logger,
    synchronize,
)

import datetime
import os
import time

import learn2learn as l2l


class MetaTrainer:
    def __init__(self, exp, args):
        # init function only defines some basic attr; other attrs like model and
        # optimizer are built in the before_train methods.
        self.exp = exp
        self.args = args

        # training related attr
        self.max_epoch = exp.max_epoch
        self.amp_training = args.fp16
        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
        self.is_distributed = get_world_size() > 1
        self.rank = get_rank()
        self.local_rank = args.local_rank
        self.device = "cuda:{}".format(self.local_rank)
        self.use_model_ema = exp.ema

        # data/dataloader related attr
        self.data_type = torch.float16 if args.fp16 else torch.float32
        self.input_size = exp.input_size
        self.best_ap = 0
        self.adaptation_period = args.adaptation_period

        # metric record
        self.meter = MeterBuffer(window_size=exp.print_interval)
        self.file_name = os.path.join(exp.output_dir, args.experiment_name)

        if self.rank == 0:
            os.makedirs(self.file_name, exist_ok=True)

        setup_logger(
            self.file_name,
            distributed_rank=self.rank,
            filename="train_log.txt",
            mode="a",
        )

    def train(self):
        self.before_train()
        try:
            self.train_in_epoch()
        except Exception:
            raise
        finally:
            self.after_train()

    def train_in_epoch(self):
        # self.evaluate_and_save_model()
        for self.epoch in range(self.start_epoch, self.max_epoch):
            self.before_epoch()
            self.train_in_task()
            self.after_epoch()

    def train_in_iter(self):
        for self.iter in range(len(self.train_loader)):
            self.before_iter()
            self.train_one_iter()
            self.after_iter()

    def train_in_task(self):
        for i in range(len(self.train_loaders)):
            self.before_task(i)
            self.train_in_iter()
            self.after_task(i)

    def before_task(self, i):
        # logger.info("init prefetcher, this might take one minute or less...")
        self.train_loader = self.train_loaders[i]
        self.prefetcher = self.prefetchers[i]
        self.learner = self.model.clone()
        logger.info("model clone created. dataloader: {}".format(i))

    def after_task(self, i):
        self.epoch_progress += len(self.train_loaders[i])

    def adapt(self, inps, targets):
        # adapt_inps = inps[:1, ...]
        # targets_inps = targets[:1, ...]
        # print(adapt_inps.shape)
        # print(targets_inps.shape)
        outputs = self.learner(inps, targets)
        loss = outputs["total_loss"]
        self.learner.adapt(loss)
        del outputs, loss

    def train_one_iter(self):
        iter_start_time = time.time()

        inps, targets = self.prefetcher.next()
        inps = inps.to(self.data_type)
        targets = targets.to(self.data_type)
        targets.requires_grad = False
        data_end_time = time.time()
        # logger.info("in train iter")

        with torch.cuda.amp.autocast(enabled=self.amp_training):
            if self.iter % self.adaptation_period == 0:
                self.adapt(inps, targets)
            outputs = self.learner(inps, targets)

        loss = outputs["total_loss"]

        for p in self.exp.all_parameters:
            if p.grad is not None:
                p.grad.data.mul_(1.0 / self.args.batch_size)

        self.optimizer.zero_grad()
        # logger.info("loss Norm: {} , scale {}".format(torch.norm(loss), self.scaler.get_scale()))
        # loss = self.scaler.scale(loss)
        # logger.info("loss Norm: {} , scale {}".format(torch.norm(loss), self.scaler.get_scale()))
        self.scaler.scale(loss).backward()
        # loss.backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        if self.use_model_ema:
            self.ema_model.update(self.model)

        lr = self.lr_scheduler.update_lr(self.progress_in_iter + 1)
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

        iter_end_time = time.time()
        self.meter.update(
            iter_time=iter_end_time - iter_start_time,
            data_time=data_end_time - iter_start_time,
            lr=lr,
            **outputs,
        )
        del loss, outputs

    def before_train(self):
        logger.info("args: {}".format(self.args))
        # logger.info("exp value:\n{}".format(self.exp))

        # model related init
        torch.cuda.set_device(self.local_rank)
        model = self.exp.get_model()
        logger.info(
            "Model Summary: {}".format(get_model_info(model, self.exp.test_size))
        )
        # exit()
        model.to(self.device)
        # from torchsummary import summary
        # summary(model, input_size=(3, 300, 300), device='cuda')

        # value of epoch will be set in `resume_train`
        self.model = l2l.algorithms.MAML(
            model, lr=self.exp.inner_lr, first_order=self.exp.first_order
        )

        # solver related init
        self.optimizer = self.exp.get_optimizer(self.args.batch_size)

        self.model = self.resume_train(self.model)

        # data related init
        self.no_aug = self.start_epoch >= self.max_epoch - self.exp.no_aug_epochs
        logger.info("Getting data loaders")
        self.train_loaders = self.exp.get_data_loaders(
            batch_size=self.args.batch_size,
            is_distributed=self.is_distributed,
            no_aug=self.no_aug,
        )
        # max_iter means iters per epoch
        self.max_iter = 0
        self.prefetchers = []
        for i, train_loader in enumerate(self.train_loaders):
            self.max_iter += len(train_loader)
            self.prefetchers.append(DataPrefetcher(train_loader))

        self.lr_scheduler = self.exp.get_lr_scheduler(
            self.exp.basic_lr_per_img * self.args.batch_size, self.max_iter
        )
        if self.args.occupy:
            occupy_mem(self.local_rank)

        if self.is_distributed:
            self.model = DDP(
                self.model, device_ids=[self.local_rank], broadcast_buffers=False
            )

        if self.use_model_ema:
            self.ema_model = ModelEMA(self.model, 0.9998)
            self.ema_model.updates = self.max_iter * self.start_epoch

        # self.model = model
        self.model.train()

        self.evaluators = self.exp.get_evaluators(
            batch_size=self.args.batch_size, is_distributed=self.is_distributed
        )
        # Tensorboard logger
        if self.rank == 0:
            self.tblogger = SummaryWriter(self.file_name)

        logger.info("Training start...")
        # logger.info("\n{}".format(model))

    def after_train(self):
        logger.info(
            "Training of experiment is done and the best AP is {:.2f}".format(
                self.best_ap * 100
            )
        )

    def before_epoch(self):
        logger.info("---> start train epoch{}".format(self.epoch + 1))
        self.epoch_progress = 0

        if self.epoch + 1 == self.max_epoch - self.exp.no_aug_epochs or self.no_aug:
            logger.info("--->No mosaic aug now!")
            for train_loader in self.train_loaders:
                train_loader.close_mosaic()
            logger.info("--->Add additional L1 loss now!")
            if self.is_distributed:
                self.model.module.head.use_l1 = True
            else:
                self.model.head.use_l1 = True
            self.exp.eval_interval = 1
            if not self.no_aug:
                self.save_ckpt(ckpt_name="last_mosaic_epoch")

    def after_epoch(self):
        self.epoch_progress = 0
        if self.use_model_ema:
            self.ema_model.update_attr(self.model)

        self.save_ckpt(ckpt_name="latest")

        if (self.epoch + 1) % self.exp.eval_interval == 0:
            all_reduce_norm(self.model)
            self.evaluate_and_save_model()

    def before_iter(self):
        pass

    def after_iter(self):
        """
        `after_iter` contains two parts of logic:
            * log information
            * reset setting of resize
        """
        # log needed information
        if (self.iter + 1) % self.exp.print_interval == 0:
            # TODO check ETA logic
            left_iters = self.max_iter * self.max_epoch - (self.progress_in_iter + 1)
            eta_seconds = self.meter["iter_time"].global_avg * left_iters
            eta_str = "ETA: {}".format(datetime.timedelta(seconds=int(eta_seconds)))

            progress_str = "epoch: {}/{}, iter: {}/{}".format(
                self.epoch + 1,
                self.max_epoch,
                self.iter + 1 + self.epoch_progress,
                self.max_iter,
            )
            loss_meter = self.meter.get_filtered_meter("loss")
            loss_str = ", ".join(
                ["{}: {:.3f}".format(k, v.latest) for k, v in loss_meter.items()]
            )

            time_meter = self.meter.get_filtered_meter("time")
            time_str = ", ".join(
                ["{}: {:.3f}s".format(k, v.avg) for k, v in time_meter.items()]
            )

            logger.info(
                "{}, mem: {:.0f}Mb, {}, {}, lr: {:.3e}".format(
                    progress_str,
                    gpu_mem_usage(),
                    time_str,
                    loss_str,
                    self.meter["lr"].latest,
                )
                + (", size: {:d}, {}".format(self.input_size[0], eta_str))
            )
            self.meter.clear_meters()

        # random resizing
        if self.exp.random_size is not None and (self.progress_in_iter + 1) % 10 == 0:
            self.input_size = self.exp.random_resize(
                self.train_loader, self.epoch, self.rank, self.is_distributed
            )

    @property
    def progress_in_iter(self):
        return self.epoch * self.max_iter + self.iter + self.epoch_progress

    def resume_train(self, model):
        if self.args.resume:
            logger.info("resume training")
            if self.args.ckpt is None:
                ckpt_file = os.path.join(self.file_name, "latest" + "_ckpt.pth.tar")
            else:
                ckpt_file = self.args.ckpt

            ckpt = torch.load(ckpt_file, map_location=self.device)
            # TODO: handle pretrained BYTETrack
            # handling meta models
            # new_dict = {}
            # for key in ckpt["model"].keys():
            #     if key.startswith('module.'):
            #         new_dict[key[7:]] = ckpt["model"][key]
            #     else:
            #         new_dict[key] = ckpt["model"][key]
            # del ckpt["model"]
            # ckpt["model"] = new_dict

            # resume the model/optimizer state dict
            model.load_state_dict(ckpt["model"])
            self.optimizer.load_state_dict(ckpt["optimizer"])
            start_epoch = (
                self.args.start_epoch - 1
                if self.args.start_epoch is not None
                else ckpt["start_epoch"]
            )
            self.start_epoch = start_epoch
            logger.info(
                "loaded checkpoint '{}' (epoch {})".format(
                    self.args.resume, self.start_epoch
                )
            )  # noqa
        else:
            if self.args.ckpt is not None:
                logger.info("loading checkpoint for fine tuning")
                ckpt_file = self.args.ckpt
                ckpt = torch.load(ckpt_file, map_location=self.device)["model"]
                model = load_ckpt(model, ckpt)
            self.start_epoch = 0

        return model

    def evaluate_and_save_model(self):
        logger.info("starting eval...")
        evalmodel = self.ema_model.ema if self.use_model_ema else self.model
        ap50_95, ap50, summary = self.exp.eval(
            evalmodel, self.evaluators, self.is_distributed
        )
        self.model.train()
        if self.rank == 0:
            self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1) logger.info("\n" + summary) synchronize() # self.best_ap = max(self.best_ap, ap50_95) self.save_ckpt("last_epoch", ap50 > self.best_ap) self.best_ap = max(self.best_ap, ap50) def save_ckpt(self, ckpt_name, update_best_ckpt=False): if self.rank == 0: save_model = self.ema_model.ema if self.use_model_ema else self.model logger.info("Save weights to {}".format(self.file_name)) ckpt_state = { "start_epoch": self.epoch + 1, "model": save_model.state_dict(), "optimizer": self.optimizer.state_dict(), } save_checkpoint( ckpt_state, update_best_ckpt, self.file_name, ckpt_name, )