Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

base_meta_exp.py 2.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # Mahdi Abdollahpour, 27/11/2021, 02:51 PM, PyCharm, ByteTrack
  2. #!/usr/bin/env python3
  3. # -*- coding:utf-8 -*-
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. import torch
  6. from torch.nn import Module
  7. from yolox.utils import LRScheduler
  8. import ast
  9. import pprint
  10. from abc import ABCMeta, abstractmethod
  11. from tabulate import tabulate
  12. from typing import Dict
  13. class BaseMetaExp(metaclass=ABCMeta):
  14. """Basic class for any experiment."""
  15. def __init__(self):
  16. self.seed = None
  17. self.output_dir = "./YOLOX_outputs"
  18. self.print_interval = 100
  19. self.eval_interval = 10
  20. @abstractmethod
  21. def get_model(self) -> Module:
  22. pass
  23. @abstractmethod
  24. def get_data_loaders(
  25. self, batch_size: int, is_distributed: bool
  26. ) -> Dict[str, torch.utils.data.DataLoader]:
  27. pass
  28. @abstractmethod
  29. def get_optimizer(self, batch_size: int) -> torch.optim.Optimizer:
  30. pass
  31. @abstractmethod
  32. def get_lr_scheduler(
  33. self, lr: float, iters_per_epoch: int, **kwargs
  34. ) -> LRScheduler:
  35. pass
  36. @abstractmethod
  37. def get_evaluator(self):
  38. pass
  39. @abstractmethod
  40. def eval(self, model, evaluator, weights):
  41. pass
  42. def __repr__(self):
  43. table_header = ["keys", "values"]
  44. exp_table = [
  45. (str(k), pprint.pformat(v))
  46. for k, v in vars(self).items()
  47. if not k.startswith("_")
  48. ]
  49. return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")
  50. def merge(self, cfg_list):
  51. assert len(cfg_list) % 2 == 0
  52. for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
  53. # only update value with same key
  54. if hasattr(self, k):
  55. src_value = getattr(self, k)
  56. src_type = type(src_value)
  57. if src_value is not None and src_type != type(v):
  58. try:
  59. v = src_type(v)
  60. except Exception:
  61. v = ast.literal_eval(v)
  62. setattr(self, k, v)