Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

base_exp.py 2.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. import torch
  5. from torch.nn import Module
  6. from yolox.utils import LRScheduler
  7. import ast
  8. import pprint
  9. from abc import ABCMeta, abstractmethod
  10. from tabulate import tabulate
  11. from typing import Dict
  12. class BaseExp(metaclass=ABCMeta):
  13. """Basic class for any experiment."""
  14. def __init__(self):
  15. self.seed = None
  16. self.output_dir = "./YOLOX_outputs"
  17. self.print_interval = 100
  18. self.eval_interval = 10
  19. @abstractmethod
  20. def get_model(self) -> Module:
  21. pass
  22. @abstractmethod
  23. def get_data_loader(
  24. self, batch_size: int, is_distributed: bool
  25. ) -> Dict[str, torch.utils.data.DataLoader]:
  26. pass
  27. @abstractmethod
  28. def get_optimizer(self, batch_size: int) -> torch.optim.Optimizer:
  29. pass
  30. @abstractmethod
  31. def get_lr_scheduler(
  32. self, lr: float, iters_per_epoch: int, **kwargs
  33. ) -> LRScheduler:
  34. pass
  35. @abstractmethod
  36. def get_evaluator(self):
  37. pass
  38. @abstractmethod
  39. def eval(self, model, evaluator, weights):
  40. pass
  41. def __repr__(self):
  42. table_header = ["keys", "values"]
  43. exp_table = [
  44. (str(k), pprint.pformat(v))
  45. for k, v in vars(self).items()
  46. if not k.startswith("_")
  47. ]
  48. return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")
  49. def merge(self, cfg_list):
  50. assert len(cfg_list) % 2 == 0
  51. for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
  52. # only update value with same key
  53. if hasattr(self, k):
  54. src_value = getattr(self, k)
  55. src_type = type(src_value)
  56. if src_value is not None and src_type != type(v):
  57. try:
  58. v = src_type(v)
  59. except Exception:
  60. v = ast.literal_eval(v)
  61. setattr(self, k, v)