Meta Byte Track
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

base_meta_exp.py 2.1KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # Mahdi Abdollahpour, 27/11/2021, 02:51 PM, PyCharm, ByteTrack
  2. #!/usr/bin/env python3
  3. # -*- coding:utf-8 -*-
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. import torch
  6. from torch.nn import Module
  7. from yolox.utils import LRScheduler
  8. import ast
  9. import pprint
  10. from abc import ABCMeta, abstractmethod
  11. from tabulate import tabulate
  12. from typing import Dict
  13. class BaseMetaExp(metaclass=ABCMeta):
  14. """Basic class for any experiment."""
  15. def __init__(self):
  16. self.seed = None
  17. # self.output_dir = "./YOLOX_outputs"
  18. self.output_dir = "./meta_experiments"
  19. self.print_interval = 100
  20. self.eval_interval = 10
  21. @abstractmethod
  22. def get_model(self) -> Module:
  23. pass
  24. @abstractmethod
  25. def get_data_loaders(
  26. self, batch_size: int, is_distributed: bool
  27. ) -> Dict[str, torch.utils.data.DataLoader]:
  28. pass
  29. @abstractmethod
  30. def get_optimizer(self, batch_size: int) -> torch.optim.Optimizer:
  31. pass
  32. @abstractmethod
  33. def get_lr_scheduler(
  34. self, lr: float, iters_per_epoch: int, **kwargs
  35. ) -> LRScheduler:
  36. pass
  37. @abstractmethod
  38. def get_evaluator(self):
  39. pass
  40. @abstractmethod
  41. def eval(self, model, evaluator, weights):
  42. pass
  43. def __repr__(self):
  44. table_header = ["keys", "values"]
  45. exp_table = [
  46. (str(k), pprint.pformat(v))
  47. for k, v in vars(self).items()
  48. if not k.startswith("_")
  49. ]
  50. return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")
  51. def merge(self, cfg_list):
  52. assert len(cfg_list) % 2 == 0
  53. for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
  54. # only update value with same key
  55. if hasattr(self, k):
  56. src_value = getattr(self, k)
  57. src_type = type(src_value)
  58. if src_value is not None and src_type != type(v):
  59. try:
  60. v = src_type(v)
  61. except Exception:
  62. v = ast.literal_eval(v)
  63. setattr(self, k, v)