#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
# Mahdi Abdollahpour, 27/11/2021, 02:51 PM, PyCharm, ByteTrack

import ast
import pprint
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Sequence

import torch
from tabulate import tabulate
from torch.nn import Module

from yolox.utils import LRScheduler

# Types whose constructor does the wrong thing when handed a string:
# bool("False") is True, tuple("(1, 2)") is a tuple of characters, etc.
# String overrides for attributes of these types are parsed as Python
# literals instead of being passed to the constructor.
_LITERAL_PARSED_TYPES = (bool, tuple, list, dict, set)


def _convert_to_type_of(src_value: Any, value: Any) -> Any:
    """Convert ``value`` so that it has the same type as ``src_value``.

    For bool/tuple/list/dict/set targets given a string, the string is parsed
    with ``ast.literal_eval`` (bools additionally accept case-insensitive
    ``"true"``/``"false"``). For every other target type, the type's
    constructor is tried first and ``ast.literal_eval`` is the fallback.

    Args:
        src_value: Current attribute value whose type is the conversion target.
        value: New value to convert.

    Returns:
        The converted value.

    Raises:
        ValueError, SyntaxError: if a string cannot be parsed as a literal.
    """
    src_type = type(src_value)
    if isinstance(value, str) and src_type in _LITERAL_PARSED_TYPES:
        text = value.strip()
        if src_type is bool and text.lower() in ("true", "false"):
            return text.lower() == "true"
        parsed = ast.literal_eval(text)
        # e.g. "[640, 640]" for a tuple attribute parses to a list; coerce it.
        return parsed if isinstance(parsed, src_type) else src_type(parsed)
    try:
        return src_type(value)
    except Exception:
        # src_type may be any type, so its constructor can raise anything;
        # fall back to interpreting the value as a Python literal.
        return ast.literal_eval(value)


class BaseMetaExp(metaclass=ABCMeta):
    """Basic class for any experiment.

    Subclasses supply the model, data loaders, optimizer, LR scheduler and
    evaluator. Public instance attributes act as configuration values: they
    are listed by ``__repr__`` and can be overridden via ``merge``.
    """

    def __init__(self):
        self.seed = None
        self.output_dir = "./meta_experiments"
        # NOTE(review): presumably logging / evaluation periods consumed by the
        # trainer (iterations resp. epochs?) -- confirm against the caller.
        self.print_interval = 100
        self.eval_interval = 10

    @abstractmethod
    def get_model(self) -> Module:
        """Build and return the model for this experiment."""

    @abstractmethod
    def get_data_loaders(
        self, batch_size: int, is_distributed: bool
    ) -> Dict[str, torch.utils.data.DataLoader]:
        """Return the experiment's data loaders keyed by name.

        The key set (e.g. train/val splits) is defined by subclasses -- TODO
        confirm the names the trainer expects.
        """

    @abstractmethod
    def get_optimizer(self, batch_size: int) -> torch.optim.Optimizer:
        """Return the optimizer; ``batch_size`` may be used to scale the LR."""

    @abstractmethod
    def get_lr_scheduler(
        self, lr: float, iters_per_epoch: int, **kwargs
    ) -> LRScheduler:
        """Return the learning-rate scheduler for the given base ``lr``."""

    @abstractmethod
    def get_evaluator(self):
        """Return the evaluator object later passed to ``eval``."""

    @abstractmethod
    def eval(self, model, evaluator, weights):
        """Evaluate ``model`` with ``evaluator``.

        NOTE(review): the meaning of ``weights`` (checkpoint path vs. loaded
        state) is subclass-defined -- confirm.
        """

    def __repr__(self):
        """Render all public attributes as a two-column table."""
        table_header = ["keys", "values"]
        exp_table = [
            (str(k), pprint.pformat(v))
            for k, v in vars(self).items()
            if not k.startswith("_")
        ]
        return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")

    def merge(self, cfg_list: Sequence[Any]) -> None:
        """Override existing attributes from a flat ``[key, value, ...]`` list.

        Only keys that already exist as attributes are updated; unknown keys
        are ignored. When the current value is not ``None`` and the new value
        has a different type, the new value is converted to the current
        value's type (see ``_convert_to_type_of``).

        Args:
            cfg_list: Alternating keys and values, e.g. ``["seed", "42"]``.

        Raises:
            AssertionError: if ``cfg_list`` has an odd number of items.
            ValueError, SyntaxError: if a value cannot be converted.
        """
        assert len(cfg_list) % 2 == 0, (
            "cfg_list must contain key/value pairs, got odd length "
            f"{len(cfg_list)}"
        )
        for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
            # only update value with same key
            if not hasattr(self, k):
                continue
            src_value = getattr(self, k)
            if src_value is not None and type(src_value) is not type(v):
                v = _convert_to_type_of(src_value, v)
            setattr(self, k, v)