|
1234567891011121314151617 |
- import torch
- import torch.nn as nn
-
# Batch-norm layer types whose running statistics are frozen / restored below.
_BATCH_NORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)


def disable_running_stats(model):
    """Freeze the running mean/variance of every batch-norm layer in *model*.

    Each matching layer's ``momentum`` is saved on the module as
    ``backup_momentum`` and then set to 0. Training-mode forward passes still
    normalise with the current batch's statistics but leave ``running_mean`` /
    ``running_var`` unchanged. Calling this repeatedly before
    :func:`enable_running_stats` is safe: the first saved momentum is kept.

    NOTE: ``num_batches_tracked`` still increments during frozen passes. Layers
    built with ``momentum=None`` (cumulative average) use that counter once
    re-enabled.
    """

    def _disable(module):
        if not isinstance(module, _BATCH_NORM_TYPES):
            return
        # Already frozen by an earlier call: keep the original backup instead
        # of overwriting it with 0, which would freeze the layer permanently.
        if hasattr(module, "backup_momentum") and module.momentum == 0:
            return
        module.backup_momentum = module.momentum
        module.momentum = 0

    model.apply(_disable)


def enable_running_stats(model):
    """Restore the batch-norm momentum saved by :func:`disable_running_stats`.

    Layers with no saved ``backup_momentum`` are left untouched. The backup
    attribute is removed after restoring, so a later ``disable_running_stats``
    captures the momentum that is current at that time.
    """

    def _enable(module):
        if isinstance(module, _BATCH_NORM_TYPES) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum
            del module.backup_momentum

    model.apply(_enable)
|