#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.

import pickle
from collections import OrderedDict

import torch
from torch import distributed as dist
from torch import nn

from .dist import _get_global_gloo_group, get_world_size

# Normalization layers that keep running statistics. Each rank updates these
# buffers from its own local batches, so they are averaged across ranks below.
ASYNC_NORM = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.InstanceNorm1d,
    nn.InstanceNorm2d,
    nn.InstanceNorm3d,
)

__all__ = [
    "get_async_norm_states",
    "pyobj2tensor",
    "tensor2pyobj",
    "all_reduce",
    "all_reduce_norm",
]


def get_async_norm_states(module):
    """Collect the state of every ASYNC_NORM layer in ``module``, keyed by
    the fully qualified buffer/parameter name."""
    async_norm_states = OrderedDict()
    for name, child in module.named_modules():
        if isinstance(child, ASYNC_NORM):
            for k, v in child.state_dict().items():
                async_norm_states[".".join([name, k])] = v
    return async_norm_states
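
# Example (illustrative): for a module containing a BatchNorm2d child named
# "backbone.bn1", the returned OrderedDict holds entries such as
# "backbone.bn1.weight", "backbone.bn1.bias", "backbone.bn1.running_mean",
# "backbone.bn1.running_var" and "backbone.bn1.num_batches_tracked".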


def pyobj2tensor(pyobj, device="cuda"):
    """serialize picklable python object to tensor"""
    storage = torch.ByteStorage.from_buffer(pickle.dumps(pyobj))
    return torch.ByteTensor(storage).to(device=device)


def tensor2pyobj(tensor):
    """deserialize tensor to picklable python object"""
    return pickle.loads(tensor.cpu().numpy().tobytes())
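
# Example (illustrative): the two helpers above are inverses of each other.
# On a machine without CUDA, pass device="cpu":
#
#   obj = {"epoch": 3, "keys": ["a", "b"]}
#   assert tensor2pyobj(pyobj2tensor(obj, device="cpu")) == obj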


def _get_reduce_op(op_name):
    return {
        "sum": dist.ReduceOp.SUM,
        "mean": dist.ReduceOp.SUM,
    }[op_name.lower()]
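
# Note: "mean" maps to ReduceOp.SUM on purpose. Rather than relying on a
# backend-specific averaging op, all_reduce() below sums across ranks and
# then divides the result by the world size itself.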


def all_reduce(py_dict, op="sum", group=None):
    """
    Apply an all-reduce operation to a python dict of tensors.

    NOTE: py_dict must have the same keys on every rank, and the
    corresponding values must have the same shapes.

    Args:
        py_dict (dict): dict of tensors to apply the all-reduce op to.
        op (str): reduce operator, either "sum" or "mean".
        group: process group used for the world-size check; defaults to the
            global gloo group. The collectives below run on the default group.
    """
    world_size = get_world_size()
    if world_size == 1:
        return py_dict
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return py_dict

    # Broadcast the key order from rank 0 so that every rank flattens its
    # tensors in the same order.
    py_key = list(py_dict.keys())
    py_key_tensor = pyobj2tensor(py_key)
    dist.broadcast(py_key_tensor, src=0)
    py_key = tensor2pyobj(py_key_tensor)

    # Flatten all values into a single buffer so one all_reduce call suffices.
    tensor_shapes = [py_dict[k].shape for k in py_key]
    tensor_numels = [py_dict[k].numel() for k in py_key]
    flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])

    dist.all_reduce(flatten_tensor, op=_get_reduce_op(op))
    if op == "mean":
        flatten_tensor /= world_size

    # Split the reduced buffer back into per-key tensors of their original shapes.
    split_tensors = [
        x.reshape(shape)
        for x, shape in zip(torch.split(flatten_tensor, tensor_numels), tensor_shapes)
    ]
    return OrderedDict({k: v for k, v in zip(py_key, split_tensors)})
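
# Example (illustrative sketch): averaging a dict of metrics across ranks.
# Every rank must make this call, and the tensors should live on the device
# the default process-group backend expects (e.g. CUDA for NCCL):
#
#   stats = {"loss": torch.tensor([0.5]).cuda(), "acc": torch.tensor([0.9]).cuda()}
#   stats = all_reduce(stats, op="mean")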


def all_reduce_norm(module):
    """
    All-reduce (average) the norm statistics of ``module`` across devices.
    """
    states = get_async_norm_states(module)
    states = all_reduce(states, op="mean")
    module.load_state_dict(states, strict=False)
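
# Usage sketch (illustrative; the surrounding training code is assumed):
# under distributed data-parallel training, each rank estimates its norm
# statistics from its own data shard. Averaging them before evaluation makes
# every rank evaluate with identical, better-estimated statistics:
#
#   model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
#   train_one_epoch(model)          # assumed training loop
#   all_reduce_norm(model.module)   # sync running_mean / running_var / ...
#   evaluate(model)                 # assumed evaluation routine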