model_utils.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. import torch
  5. import torch.nn as nn
  6. from thop import profile
  7. from copy import deepcopy
  # Names exported by ``from <this module> import *``.
  __all__ = ["fuse_conv_and_bn", "fuse_model", "get_model_info", "replace_module"]
  def get_model_info(model, tsize):
      """
      Return a one-line summary of a model's size and estimated compute.

      The model is profiled with thop on a 1 x 3 x 64 x 64 dummy input. The
      operation count is then rescaled by pixel area to ``tsize``, so the
      estimate assumes cost grows linearly with input area (as it does for
      fully-convolutional networks). Profiling runs on a deep copy of
      ``model``.

      Args:
          model (nn.Module): model to profile. It must accept a 3-channel
              image batch, because the dummy input has 3 channels.
          tsize (Sequence[int]): target input size. Only
              ``tsize[0] * tsize[1]`` is used (presumably ``(height, width)``).

      Returns:
          str: ``"Params: <P>M, Gflops: <G>"`` with two decimals each.
      """
      stride = 64
      # Create the dummy input on the model's device and in its floating-point
      # dtype. A fixed float32 input would fail against, e.g., a half-precision
      # model. A parameter-less model falls back to a float32 CPU input instead
      # of raising a bare StopIteration.
      first_param = next(model.parameters(), None)
      device = first_param.device if first_param is not None else torch.device("cpu")
      dtype = (
          first_param.dtype
          if first_param is not None and first_param.is_floating_point()
          else torch.float32
      )
      img = torch.zeros((1, 3, stride, stride), device=device, dtype=dtype)
      flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
      params /= 1e6
      flops /= 1e9
      # Rescale from stride x stride to tsize. thop reports multiply-accumulates,
      # so the factor of 2 converts them to FLOPs.
      flops *= tsize[0] * tsize[1] / stride / stride * 2  # Gflops
      info = "Params: {:.2f}M, Gflops: {:.2f}".format(params, flops)
      return info
  def fuse_conv_and_bn(conv, bn):
      """
      Fold a batchnorm layer into the convolution that feeds it.

      The returned convolution computes ``bn(conv(x))`` in a single
      ``nn.Conv2d``, with ``bn`` using its running statistics (its eval-mode
      behavior). Derivation: https://tehnokv.com/posts/fusing-batchnorm-and-conv/

      Args:
          conv (nn.Conv2d): convolution to fuse. Its kernel size, stride,
              padding, dilation, groups, padding mode, device and dtype are
              carried over to the result.
          bn (nn.BatchNorm2d): batchnorm applied to ``conv``'s output. It must
              have running statistics. ``affine=False`` (no weight/bias) is
              supported.

      Returns:
          nn.Conv2d: a new convolution with a bias, whose parameters do not
          require grad. ``conv`` and ``bn`` are only read, never modified.
      """
      fusedconv = (
          nn.Conv2d(
              conv.in_channels,
              conv.out_channels,
              kernel_size=conv.kernel_size,
              stride=conv.stride,
              padding=conv.padding,
              dilation=conv.dilation,
              groups=conv.groups,
              bias=True,
              padding_mode=conv.padding_mode,
          )
          .requires_grad_(False)
          .to(device=conv.weight.device, dtype=conv.weight.dtype)
      )

      # The fused parameters are constants; do not record autograd history
      # linking them to conv/bn parameters.
      with torch.no_grad():
          # Per-output-channel factor gamma / sqrt(running_var + eps). gamma is
          # implicitly 1 when the batchnorm has no affine weight.
          scale = 1.0 / torch.sqrt(bn.running_var + bn.eps)
          if bn.weight is not None:
              scale = scale * bn.weight
          # Per-output-channel offset beta - running_mean * scale (beta is 0 if absent).
          shift = -bn.running_mean * scale
          if bn.bias is not None:
              shift = shift + bn.bias

          # W_fused = diag(scale) @ W_conv, applied by broadcasting per output channel.
          w_conv = conv.weight.reshape(conv.out_channels, -1)
          fusedconv.weight.copy_(
              (w_conv * scale.unsqueeze(1)).reshape(fusedconv.weight.shape)
          )

          # b_fused = scale * b_conv + shift, where b_conv is 0 if conv has no bias.
          b_conv = (
              torch.zeros(conv.out_channels, device=scale.device, dtype=scale.dtype)
              if conv.bias is None
              else conv.bias
          )
          fusedconv.bias.copy_(b_conv * scale + shift)

      return fusedconv
  def fuse_model(model):
      """
      Fold the batchnorm of every ``BaseConv`` block in ``model`` into its conv.

      For each block whose type is exactly ``BaseConv`` and which still has a
      ``bn`` attribute:

      - ``conv`` is replaced by ``fuse_conv_and_bn(conv, bn)``;
      - ``bn`` is deleted;
      - ``forward`` is rebound to the block's ``fuseforward``.

      ``fuseforward`` is defined on ``BaseConv`` and is not shown here;
      presumably it is the conv + activation path without batchnorm. Blocks
      without ``bn`` (e.g. already fused) are skipped.

      Args:
          model (nn.Module): model to fuse; its submodules are modified in place.

      Returns:
          nn.Module: the same ``model`` object.
      """
      from yolox.models.network_blocks import BaseConv

      for block in model.modules():
          # Exact type check: subclasses of BaseConv are left untouched.
          if type(block) is not BaseConv or not hasattr(block, "bn"):
              continue
          block.conv = fuse_conv_and_bn(block.conv, block.bn)
          delattr(block, "bn")
          block.forward = block.fuseforward
      return model
  def replace_module(module, replaced_module_type, new_module_type, replace_func=None):
      """
      Replace every submodule of a given type with a new module. Mostly used in deploy.

      Args:
          module (nn.Module): model to apply the replacement to. Matching
              children are swapped in place via ``add_module``.
          replaced_module_type (Type): module type to be replaced. It is
              matched with ``isinstance``, so subclasses are replaced too.
          new_module_type (Type): type of the replacement module.
          replace_func (Callable[[Type, Type], nn.Module], optional): called as
              ``replace_func(replaced_module_type, new_module_type)`` to build
              each replacement, at every nesting level. Defaults to
              ``new_module_type()``.

      Returns:
          nn.Module: the replacement if ``module`` itself matches; otherwise
          ``module`` with all matching descendants replaced.
      """

      def default_replace_func(replaced_module_type, new_module_type):
          return new_module_type()

      if replace_func is None:
          replace_func = default_replace_func

      if isinstance(module, replaced_module_type):
          return replace_func(replaced_module_type, new_module_type)

      # Recursively replace. replace_func is forwarded so that custom
      # construction logic also applies to nested children.
      for name, child in module.named_children():
          new_child = replace_module(
              child, replaced_module_type, new_module_type, replace_func
          )
          if new_child is not child:  # child was replaced
              module.add_module(name, new_child)
      return module