#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
import math
from copy import deepcopy

import torch
import torch.nn as nn


def is_parallel(model):
    """Check if the model is wrapped in a parallel container."""
    parallel_type = (
        nn.parallel.DataParallel,
        nn.parallel.DistributedDataParallel,
    )
    return isinstance(model, parallel_type)


def copy_attr(a, b, include=(), exclude=()):
    # Copy attributes from b to a, with options to only include [...] and to exclude [...]
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith("_") or k in exclude:
            continue
        else:
            setattr(a, k, v)


class ModelEMA:
    """
    Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        """
        Args:
            model (nn.Module): model to apply EMA to.
            decay (float): EMA decay rate.
            updates (int): counter of EMA updates.
        """
        # Create EMA (FP32)
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()
        self.updates = updates
        # decay exponential ramp (to help early epochs)
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = (
                model.module.state_dict() if is_parallel(model) else model.state_dict()
            )  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1.0 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)
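

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): wire ModelEMA
# into a toy training loop. The tiny model, optimizer settings, data shapes,
# and step count below are arbitrary assumptions chosen only for demonstration.
if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.MSELoss()
    ema = ModelEMA(model, decay=0.9999)

    for step in range(10):
        x = torch.randn(32, 8)
        y = torch.randn(32, 1)
        loss = criterion(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(model)  # refresh the smoothed copy after every optimizer step

    # Evaluate or checkpoint with the smoothed weights held in ema.ema.
    with torch.no_grad():
        print(ema.ema(torch.randn(1, 8)))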