from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.utils.logger import setup_logger
from ppdet.core.workspace import register, serializable
from paddle.utils import try_import

logger = setup_logger(__name__)


@register
@serializable
class OFA(object):
    def __init__(self, ofa_config):
        super(OFA, self).__init__()
        self.ofa_config = ofa_config

    def __call__(self, model, param_state_dict):
        paddleslim = try_import('paddleslim')
        from paddleslim.nas.ofa import OFA, RunConfig, utils
        from paddleslim.nas.ofa.convert_super import Convert, supernet

        task = self.ofa_config['task']
        expand_ratio = self.ofa_config['expand_ratio']

        skip_neck = self.ofa_config['skip_neck']
        skip_head = self.ofa_config['skip_head']

        run_config = self.ofa_config['RunConfig']
        if 'skip_layers' in run_config:
            skip_layers = run_config['skip_layers']
        else:
            skip_layers = []

        # supernet config
        sp_config = supernet(expand_ratio=expand_ratio)
        # convert to supernet
        model = Convert(sp_config).convert(model)
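        # After conversion, the convolution and linear layers of the original
        # detector are wrapped in paddleslim "super" layers, so their width
        # can later be sampled from the candidate expand ratios.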

        skip_names = []
        if skip_neck:
            skip_names.append('neck.')
        if skip_head:
            skip_names.append('head.')
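        # Collect the full names of all sublayers under the skipped modules;
        # these names are handed to RunConfig so OFA does not shrink the neck
        # or head during the search.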
        for name, sublayer in model.named_sublayers():
            for n in skip_names:
                if n in name:
                    skip_layers.append(name)

        run_config['skip_layers'] = skip_layers
        run_config = RunConfig(**run_config)

        # build ofa model
        ofa_model = OFA(model, run_config=run_config)

        ofa_model.set_epoch(0)
        ofa_model.set_task(task)

        input_spec = [{
            "image": paddle.ones(
                shape=[1, 3, 640, 640], dtype='float32'),
            "im_shape": paddle.full(
                [1, 2], 640, dtype='float32'),
            "scale_factor": paddle.ones(
                shape=[1, 2], dtype='float32')
        }]
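        # The spec mimics a detector feed dict (image, im_shape, scale_factor)
        # with a fixed 640x640 dummy input; it is only used to trace the model
        # once so the search space can be pruned to the layers actually used.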
        ofa_model._clear_search_space(input_spec=input_spec)
        ofa_model._build_ss = True
        check_ss = ofa_model._sample_config('expand_ratio', phase=None)
        # tokenize the search space
        ofa_model.tokenize()
        # check token map, search cands and search space
        logger.info('Token map is {}'.format(ofa_model.token_map))
        logger.info('Search candidates is {}'.format(ofa_model.search_cands))
        logger.info('The length of search_space is {}, search_space is {}'.
                    format(len(ofa_model._ofa_layers), ofa_model._ofa_layers))
        # set model state dict into ofa model
        utils.set_state_dict(ofa_model.model, param_state_dict)
        return ofa_model
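

# A minimal usage sketch (not part of the original module; the concrete config
# values below are illustrative assumptions): because the class is registered
# in ppdet's workspace, it is normally created from a slim YAML config and then
# called on a detector model together with that model's weights, e.g.:
#
#     ofa_config = {
#         'task': 'expand_ratio',
#         'expand_ratio': [0.25, 0.5, 1.0],
#         'skip_neck': True,
#         'skip_head': True,
#         'RunConfig': {...},  # fields forwarded to paddleslim.nas.ofa.RunConfig
#     }
#     slim = OFA(ofa_config)
#     ofa_model = slim(model, model.state_dict())
#
# The returned ofa_model wraps the converted supernet and is used in place of
# the original detector during supernet training and sub-network search.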