Meta Byte Track

ema.py 2.5KB

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
import torch
import torch.nn as nn
import math
from copy import deepcopy


def is_parallel(model):
    """Check if model is in parallel mode."""
    parallel_type = (
        nn.parallel.DataParallel,
        nn.parallel.DistributedDataParallel,
    )
    return isinstance(model, parallel_type)


def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a, with options to only include [...] and to exclude [...]
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith("_") or k in exclude:
            continue
        else:
            setattr(a, k, v)


class ModelEMA:
    """
    Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        """
        Args:
            model (nn.Module): model to apply EMA.
            decay (float): EMA decay rate.
            updates (int): counter of EMA updates.
        """
        # Create EMA (FP32)
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()
        self.updates = updates
        # decay exponential ramp (to help early epochs)
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = (
                model.module.state_dict() if is_parallel(model) else model.state_dict()
            )  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1.0 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)
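For context, here is a minimal sketch of how ModelEMA is typically wired into a training loop. The model, optimizer, and loss used below are placeholder assumptions for illustration and are not part of this file; only ModelEMA itself comes from ema.py above.

    # Minimal usage sketch (assumed training-loop setup; ModelEMA is defined in ema.py above).
    import torch
    import torch.nn as nn

    model = nn.Linear(10, 2)
    ema = ModelEMA(model, decay=0.9999)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for step in range(100):
        x = torch.randn(8, 10)
        loss = model(x).pow(2).mean()  # placeholder loss for illustration
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(model)  # EMA follows the raw weights after every optimizer step

    eval_model = ema.ema  # evaluate or checkpoint with the smoothed weights

Note that the effective decay ramps up from 0 toward the configured value via decay * (1 - exp(-updates / 2000)), so early updates weight the current model more heavily before the average settles.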