123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- #!/usr/bin/env python3
- # -*- coding:utf-8 -*-
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
-
- import torch
- from torch import distributed as dist
- from torch import nn
-
- import pickle
- from collections import OrderedDict
-
- from .dist import _get_global_gloo_group, get_world_size
-
# Normalization layer types whose state is collected by get_async_norm_states
# and averaged across processes by all_reduce_norm.
ASYNC_NORM = (
    nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
) + (
    nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
)
-
# Names exported by `from <this module> import *`.
__all__ = ["get_async_norm_states", "pyobj2tensor", "tensor2pyobj",
           "all_reduce", "all_reduce_norm"]
-
-
def get_async_norm_states(module, norm_types=None):
    """Collect the state of every normalization layer inside ``module``.

    Args:
        module (nn.Module): root module to search; the root itself is also
            checked, since ``named_modules()`` yields it first.
        norm_types (type or tuple of types, optional): layer classes to
            collect. Defaults to ``ASYNC_NORM``.

    Returns:
        OrderedDict mapping state_dict keys relative to ``module`` (dotted
        submodule path followed by the layer's own state key) to the layer's
        state tensors, so the result can be passed back to
        ``module.load_state_dict``.
    """
    if norm_types is None:
        norm_types = ASYNC_NORM
    async_norm_states = OrderedDict()
    for name, child in module.named_modules():
        if not isinstance(child, norm_types):
            continue
        # named_modules() reports the root module under the empty name "".
        # Joining unconditionally would yield keys like ".weight", which
        # load_state_dict(strict=False) silently ignores as unexpected.
        prefix = name + "." if name else ""
        for k, v in child.state_dict().items():
            async_norm_states[prefix + k] = v
    return async_norm_states
-
-
def pyobj2tensor(pyobj, device="cuda"):
    """Serialize a picklable Python object into a 1-D uint8 tensor.

    Args:
        pyobj: any picklable Python object.
        device: device the returned tensor is moved to.

    Returns:
        torch.Tensor of dtype uint8 whose elements are the pickle bytes.
    """
    payload = pickle.dumps(pyobj)
    if hasattr(torch, "frombuffer"):
        # torch.frombuffer replaces the deprecated TypedStorage/ByteStorage
        # route. A bytearray supplies a writable buffer, which avoids
        # frombuffer's warning about non-writable inputs.
        tensor = torch.frombuffer(bytearray(payload), dtype=torch.uint8)
    else:
        # Older torch releases without torch.frombuffer.
        tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(payload))
    return tensor.to(device=device)
-
-
def tensor2pyobj(tensor):
    """Inverse of ``pyobj2tensor``: unpickle the raw bytes held by ``tensor``.

    The tensor is first copied to host memory. Unpickling can execute
    arbitrary code, so only use this on data produced by trusted peers.
    """
    host_array = tensor.cpu().numpy()
    raw_bytes = host_array.tobytes()
    return pickle.loads(raw_bytes)
-
-
def _get_reduce_op(op_name):
    """Translate an op name into a ``dist.ReduceOp`` (case-insensitive).

    Both "sum" and "mean" map to SUM; for "mean" the caller is expected to
    divide the summed result afterwards. Any other name raises KeyError.
    """
    supported_ops = dict(
        sum=dist.ReduceOp.SUM,
        mean=dist.ReduceOp.SUM,
    )
    return supported_ops[op_name.lower()]
-
-
def all_reduce(py_dict, op="sum", group=None):
    """
    Apply all reduce function for python dict object.

    NOTE: make sure that every py_dict has the same keys and values are in the same shape.

    Args:
        py_dict (dict): dict of tensors to apply the all reduce op to.
        op (str): operator, could be "sum" or "mean" (case-insensitive).
        group (ProcessGroup, optional): group to reduce over. Defaults to the
            group returned by ``_get_global_gloo_group()``. Global rank 0 must
            belong to it, because rank 0 broadcasts the key order.

    Returns:
        ``py_dict`` itself when there is nothing to reduce (single process,
        single-member group, or empty dict). Otherwise a new OrderedDict whose
        reduced tensors keep their original shapes and dtypes.

    Raises:
        KeyError: if ``op`` is neither "sum" nor "mean".
    """
    world_size = get_world_size()
    if world_size == 1:
        return py_dict
    if group is None:
        group = _get_global_gloo_group()
    group_size = dist.get_world_size(group)
    if group_size == 1:
        return py_dict
    if not py_dict:
        # torch.cat([]) would raise. With identical keys on every rank (see
        # NOTE), all ranks take this branch together.
        return py_dict

    # Normalize once so that "MEAN" both maps to SUM and triggers the division
    # below. An invalid op also fails here, before any rank enters a collective.
    op = op.lower()
    reduce_op = _get_reduce_op(op)

    # Every rank adopts rank 0's key order so the flattened buffers line up.
    # The key tensor lives on the same device as the values rather than a
    # hard-coded CUDA device, so CPU-only groups work too.
    py_key = list(py_dict.keys())
    key_device = py_dict[py_key[0]].device
    py_key_tensor = pyobj2tensor(py_key, device=key_device)
    dist.broadcast(py_key_tensor, src=0, group=group)
    py_key = tensor2pyobj(py_key_tensor)

    tensors = [py_dict[k] for k in py_key]
    flatten_tensor = torch.cat([t.flatten() for t in tensors])
    dist.all_reduce(flatten_tensor, op=reduce_op, group=group)
    if op == "mean":
        # Divide by the size of the group actually reduced over. Use
        # out-of-place division so integer buffers are promoted instead of
        # failing on in-place true division.
        flatten_tensor = flatten_tensor / group_size

    chunks = torch.split(flatten_tensor, [t.numel() for t in tensors])
    # torch.cat may have promoted dtypes (e.g. int64 counters mixed with
    # floats); cast each piece back to its input dtype.
    return OrderedDict(
        (k, chunk.reshape(t.shape).to(t.dtype))
        for k, chunk, t in zip(py_key, chunks, tensors)
    )
-
-
def all_reduce_norm(module):
    """
    Average the normalization-layer state of ``module`` across processes and
    load the averaged values back into ``module`` (non-strict load).
    """
    norm_states = get_async_norm_states(module)
    averaged = all_reduce(norm_states, op="mean")
    module.load_state_dict(averaged, strict=False)
|