ofa.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import paddle
  5. import paddle.nn as nn
  6. import paddle.nn.functional as F
  7. from ppdet.core.workspace import load_config, merge_config, create
  8. from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
  9. from ppdet.utils.logger import setup_logger
  10. from ppdet.core.workspace import register, serializable
  11. from paddle.utils import try_import
  12. logger = setup_logger(__name__)
  13. @register
  14. @serializable
  15. class OFA(object):
  16. def __init__(self, ofa_config):
  17. super(OFA, self).__init__()
  18. self.ofa_config = ofa_config
  19. def __call__(self, model, param_state_dict):
  20. paddleslim = try_import('paddleslim')
  21. from paddleslim.nas.ofa import OFA, RunConfig, utils
  22. from paddleslim.nas.ofa.convert_super import Convert, supernet
  23. task = self.ofa_config['task']
  24. expand_ratio = self.ofa_config['expand_ratio']
  25. skip_neck = self.ofa_config['skip_neck']
  26. skip_head = self.ofa_config['skip_head']
  27. run_config = self.ofa_config['RunConfig']
  28. if 'skip_layers' in run_config:
  29. skip_layers = run_config['skip_layers']
  30. else:
  31. skip_layers = []
  32. # supernet config
  33. sp_config = supernet(expand_ratio=expand_ratio)
  34. # convert to supernet
  35. model = Convert(sp_config).convert(model)
  36. skip_names = []
  37. if skip_neck:
  38. skip_names.append('neck.')
  39. if skip_head:
  40. skip_names.append('head.')
  41. for name, sublayer in model.named_sublayers():
  42. for n in skip_names:
  43. if n in name:
  44. skip_layers.append(name)
  45. run_config['skip_layers'] = skip_layers
  46. run_config = RunConfig(**run_config)
  47. # build ofa model
  48. ofa_model = OFA(model, run_config=run_config)
  49. ofa_model.set_epoch(0)
  50. ofa_model.set_task(task)
  51. input_spec = [{
  52. "image": paddle.ones(
  53. shape=[1, 3, 640, 640], dtype='float32'),
  54. "im_shape": paddle.full(
  55. [1, 2], 640, dtype='float32'),
  56. "scale_factor": paddle.ones(
  57. shape=[1, 2], dtype='float32')
  58. }]
  59. ofa_model._clear_search_space(input_spec=input_spec)
  60. ofa_model._build_ss = True
  61. check_ss = ofa_model._sample_config('expand_ratio', phase=None)
  62. # tokenize the search space
  63. ofa_model.tokenize()
  64. # check token map, search cands and search space
  65. logger.info('Token map is {}'.format(ofa_model.token_map))
  66. logger.info('Search candidates is {}'.format(ofa_model.search_cands))
  67. logger.info('The length of search_space is {}, search_space is {}'.
  68. format(len(ofa_model._ofa_layers), ofa_model._ofa_layers))
  69. # set model state dict into ofa model
  70. utils.set_state_dict(ofa_model.model, param_state_dict)
  71. return ofa_model