123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- # Mahdi Abdollahpour, 27/11/2021, 02:51 PM, PyCharm, ByteTrack
-
-
- #!/usr/bin/env python3
- # -*- coding:utf-8 -*-
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
-
- import torch
- from torch.nn import Module
-
- from yolox.utils import LRScheduler
-
- import ast
- import pprint
- from abc import ABCMeta, abstractmethod
- from tabulate import tabulate
- from typing import Dict
-
-
class BaseMetaExp(metaclass=ABCMeta):
    """Abstract base class for an experiment description.

    Concrete experiments subclass this, implement every abstract factory
    method below, and expose their settings as plain instance attributes.
    Those attributes can be overridden in bulk with :meth:`merge` and are
    rendered as a table by :meth:`__repr__`.
    """

    # Case-insensitive spellings accepted by ``merge`` for boolean overrides.
    _BOOL_WORDS = {"true": True, "false": False}

    def __init__(self):
        """Initialise the settings shared by every experiment."""
        self.seed = None
        self.output_dir = "./YOLOX_outputs"
        # NOTE(review): units of the two intervals are not visible in this
        # file (presumably iterations / epochs) -- confirm against the trainer.
        self.print_interval = 100
        self.eval_interval = 10

    @abstractmethod
    def get_model(self) -> Module:
        """Return the model (a ``torch.nn.Module``) for this experiment."""

    @abstractmethod
    def get_data_loaders(
        self, batch_size: int, is_distributed: bool
    ) -> Dict[str, torch.utils.data.DataLoader]:
        """Return the experiment's data loaders as a name -> loader mapping."""

    @abstractmethod
    def get_optimizer(self, batch_size: int) -> torch.optim.Optimizer:
        """Return the optimizer to use for the given batch size."""

    @abstractmethod
    def get_lr_scheduler(
        self, lr: float, iters_per_epoch: int, **kwargs
    ) -> "LRScheduler":
        """Return the learning-rate scheduler for this experiment."""

    @abstractmethod
    def get_evaluator(self):
        """Return the evaluator object used by :meth:`eval`."""

    @abstractmethod
    def eval(self, model, evaluator, weights):
        """Evaluate ``model`` with ``evaluator``; semantics are subclass-defined."""

    def __repr__(self):
        """Render every public instance attribute as a two-column table."""
        table_header = ["keys", "values"]
        exp_table = [
            (str(k), pprint.pformat(v))
            for k, v in vars(self).items()
            if not k.startswith("_")
        ]
        return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")

    @staticmethod
    def _coerce(value, target_type):
        """Convert ``value`` to ``target_type`` for :meth:`merge`.

        For string input aimed at a ``bool`` or a builtin container, the text
        is parsed as a Python literal *first*: calling the constructor
        directly never fails for these types but yields garbage
        (``bool("False")`` is ``True``; ``tuple("(1, 2)")`` is a tuple of
        characters).  All other cases use the constructor and fall back to
        ``ast.literal_eval`` when the constructor raises.

        Raises:
            ValueError, SyntaxError: if neither strategy can convert
                ``value`` (propagated from ``ast.literal_eval``).
        """
        if isinstance(value, str) and target_type in (bool, list, tuple, dict, set):
            text = value.strip()
            if target_type is bool and text.lower() in BaseMetaExp._BOOL_WORDS:
                return BaseMetaExp._BOOL_WORDS[text.lower()]
            try:
                parsed = ast.literal_eval(text)
            except (ValueError, SyntaxError, TypeError):
                # Not a literal: fall through to the plain constructor below.
                pass
            else:
                if isinstance(parsed, target_type):
                    return parsed
                try:
                    # e.g. "[1, 2]" supplied for a tuple-valued attribute.
                    return target_type(parsed)
                except (TypeError, ValueError):
                    # Literal of an incompatible shape: use the constructor.
                    pass
        try:
            return target_type(value)
        except Exception:
            return ast.literal_eval(value)

    def merge(self, cfg_list):
        """Override existing attributes from a flat ``[key, value, ...]`` list.

        Keys that are not already attributes of ``self`` are ignored.  When
        the current value is not ``None`` and the new value has a different
        type, the new value is converted to the current value's type (see
        :meth:`_coerce`); otherwise it is stored unchanged.

        Args:
            cfg_list: sequence of alternating keys and values, typically the
                raw strings taken from the command line.

        Raises:
            AssertionError: if ``cfg_list`` has an odd number of items.
        """
        if len(cfg_list) % 2 != 0:
            # Raised explicitly (rather than via ``assert``) so the check is
            # not stripped under ``python -O``; the exception type is kept
            # for backward compatibility.
            raise AssertionError(
                f"cfg_list must hold key/value pairs, got {len(cfg_list)} items"
            )
        for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
            # Only update attributes that already exist on the experiment.
            if not hasattr(self, key):
                continue
            current = getattr(self, key)
            if current is not None and type(current) is not type(value):
                value = self._coerce(value, type(current))
            setattr(self, key, value)
|