3
0

trainer.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. import copy
  20. import time
  21. from tqdm import tqdm
  22. import numpy as np
  23. import typing
  24. from PIL import Image, ImageOps, ImageFile
  25. ImageFile.LOAD_TRUNCATED_IMAGES = True
  26. import paddle
  27. import paddle.nn as nn
  28. import paddle.distributed as dist
  29. from paddle.distributed import fleet
  30. from paddle import amp
  31. from paddle.static import InputSpec
  32. from ppdet.optimizer import ModelEMA
  33. from ppdet.core.workspace import create
  34. from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
  35. from ppdet.utils.visualizer import visualize_results, save_result
  36. from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
  37. from ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric
  38. from ppdet.data.source.sniper_coco import SniperCOCODataSet
  39. from ppdet.data.source.category import get_categories
  40. import ppdet.utils.stats as stats
  41. from ppdet.utils import profiler
  42. from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator
  43. from .export_utils import _dump_infer_config, _prune_input_spec
  44. from ppdet.utils.logger import setup_logger
  45. logger = setup_logger('ppdet.engine')
  46. __all__ = ['Trainer']
  47. MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT', 'ByteTrack']
class Trainer(object):
    """Train, evaluate, run inference with, and export a ppdet model.

    Which components are built at construction (dataset, loader, optimizer,
    callbacks, metrics) depends on ``mode``; they are all obtained from
    ``cfg`` through ``create``.
    """

    def __init__(self, cfg, mode='train'):
        """Build the dataset, the model and the mode-specific components.

        Args:
            cfg: global config. It is accessed both as attributes
                (``cfg.architecture``) and as items (``cfg['TestReader']``),
                and the created dataset is stored back into it.
            mode (str): 'train', 'eval' or 'test' (case-insensitive).

        Side effects:
            * 'train' mode builds ``self.loader``, ``self.lr`` and
              ``self.optimizer`` (plus ``self.pruner`` when
              ``unstructured_prune`` is set); 'eval' mode builds
              ``self.loader``; 'test' mode builds no loader here.
            * Calls ``sys.exit(1)`` for DeepSORT in 'train' mode.
        """
        self.cfg = cfg
        assert mode.lower() in ['train', 'eval', 'test'], \
                "mode should be 'train', 'eval' or 'test'"
        self.mode = mode.lower()
        self.optimizer = None
        # Becomes True below when cfg already carries a built model; then
        # load_weights() returns without loading anything.
        self.is_loaded_weights = False

        # build data loader
        capital_mode = self.mode.capitalize()
        # MOT architectures use the '{Mode}MOTDataset' section in eval/test
        # mode. Either way, the created dataset is written back into cfg.
        if cfg.architecture in MOT_ARCH and self.mode in ['eval', 'test']:
            self.dataset = self.cfg['{}MOTDataset'.format(
                capital_mode)] = create('{}MOTDataset'.format(capital_mode))()
        else:
            self.dataset = self.cfg['{}Dataset'.format(capital_mode)] = create(
                '{}Dataset'.format(capital_mode))()

        if cfg.architecture == 'DeepSORT' and self.mode == 'train':
            logger.error('DeepSORT has no need of training on mot dataset.')
            sys.exit(1)

        if cfg.architecture == 'FairMOT' and self.mode == 'eval':
            # FairMOT evaluation runs over every image of every sequence.
            images = self.parse_mot_images(cfg)
            self.dataset.set_images(images)

        if self.mode == 'train':
            self.loader = create('{}Reader'.format(capital_mode))(
                self.dataset, cfg.worker_num)

        # The embedding-head configs are patched with identity counts taken
        # from the dataset. This happens before the model is created below,
        # presumably so that create() picks the values up.
        if cfg.architecture == 'JDE' and self.mode == 'train':
            cfg['JDEEmbeddingHead'][
                'num_identities'] = self.dataset.num_identities_dict[0]
            # JDE only supports single-class MOT for now.

        if cfg.architecture == 'FairMOT' and self.mode == 'train':
            cfg['FairMOTEmbeddingHead'][
                'num_identities_dict'] = self.dataset.num_identities_dict
            # FairMOT supports both single-class and multi-class MOT.

        # build model
        if 'model' not in self.cfg:
            self.model = create(cfg.architecture)
        else:
            self.model = self.cfg.model
            self.is_loaded_weights = True

        if cfg.architecture == 'YOLOX':
            # Override BatchNorm epsilon/momentum on every BN sublayer.
            for k, m in self.model.named_sublayers():
                if isinstance(m, nn.BatchNorm2D):
                    m._epsilon = 1e-3  # for amp(fp16)
                    m._momentum = 0.97  # 0.03 in pytorch

        # normalize params for deploy: hand the TestReader sample_transforms
        # to load_meanstd() of the network that is deployed (the wrapped
        # model for OFA, the student for Distill and for DistillPrune in
        # train mode, the model itself otherwise).
        if 'slim' in cfg and cfg['slim_type'] == 'OFA':
            self.model.model.load_meanstd(cfg['TestReader'][
                'sample_transforms'])
        elif 'slim' in cfg and cfg['slim_type'] == 'Distill':
            self.model.student_model.load_meanstd(cfg['TestReader'][
                'sample_transforms'])
        elif 'slim' in cfg and cfg[
                'slim_type'] == 'DistillPrune' and self.mode == 'train':
            self.model.student_model.load_meanstd(cfg['TestReader'][
                'sample_transforms'])
        else:
            self.model.load_meanstd(cfg['TestReader']['sample_transforms'])

        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
        if self.use_ema:
            # Config defaults: decay 0.9998, decay type 'threshold',
            # cycle_epoch -1.
            ema_decay = self.cfg.get('ema_decay', 0.9998)
            cycle_epoch = self.cfg.get('cycle_epoch', -1)
            ema_decay_type = self.cfg.get('ema_decay_type', 'threshold')
            self.ema = ModelEMA(
                self.model,
                decay=ema_decay,
                ema_decay_type=ema_decay_type,
                cycle_epoch=cycle_epoch)

        # EvalDataset build with BatchSampler to evaluate in single device
        # TODO: multi-device evaluate
        if self.mode == 'eval':
            if cfg.architecture == 'FairMOT':
                self.loader = create('EvalMOTReader')(self.dataset, 0)
            else:
                self._eval_batch_sampler = paddle.io.BatchSampler(
                    self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
                reader_name = '{}Reader'.format(self.mode.capitalize())
                # If metric is VOC, need to be set collate_batch=False.
                if cfg.metric == 'VOC':
                    cfg[reader_name]['collate_batch'] = False
                self.loader = create(reader_name)(self.dataset, cfg.worker_num,
                                                  self._eval_batch_sampler)
        # TestDataset build after user set images, skip loader creation here

        # build optimizer in train mode
        if self.mode == 'train':
            steps_per_epoch = len(self.loader)
            self.lr = create('LearningRate')(steps_per_epoch)
            self.optimizer = create('OptimizerBuilder')(self.lr, self.model)

            # Unstructured pruner is only enabled in the train mode.
            if self.cfg.get('unstructured_prune'):
                self.pruner = create('UnstructuredPruner')(self.model,
                                                           steps_per_epoch)

        self._nranks = dist.get_world_size()
        self._local_rank = dist.get_rank()

        # Shared state dict handed to every callback hook.
        self.status = {}

        self.start_epoch = 0
        self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch

        # initial default callbacks
        self._init_callbacks()

        # initial default metrics
        self._init_metrics()
        self._reset_metrics()
  149. def _init_callbacks(self):
  150. if self.mode == 'train':
  151. self._callbacks = [LogPrinter(self), Checkpointer(self)]
  152. if self.cfg.get('use_vdl', False):
  153. self._callbacks.append(VisualDLWriter(self))
  154. if self.cfg.get('save_proposals', False):
  155. self._callbacks.append(SniperProposalsGenerator(self))
  156. self._compose_callback = ComposeCallback(self._callbacks)
  157. elif self.mode == 'eval':
  158. self._callbacks = [LogPrinter(self)]
  159. if self.cfg.metric == 'WiderFace':
  160. self._callbacks.append(WiferFaceEval(self))
  161. self._compose_callback = ComposeCallback(self._callbacks)
  162. elif self.mode == 'test' and self.cfg.get('use_vdl', False):
  163. self._callbacks = [VisualDLWriter(self)]
  164. self._compose_callback = ComposeCallback(self._callbacks)
  165. else:
  166. self._callbacks = []
  167. self._compose_callback = None
  168. def _init_metrics(self, validate=False):
  169. if self.mode == 'test' or (self.mode == 'train' and not validate):
  170. self._metrics = []
  171. return
  172. classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False
  173. if self.cfg.metric == 'COCO' or self.cfg.metric == "SNIPERCOCO":
  174. # TODO: bias should be unified
  175. bias = self.cfg['bias'] if 'bias' in self.cfg else 0
  176. output_eval = self.cfg['output_eval'] \
  177. if 'output_eval' in self.cfg else None
  178. save_prediction_only = self.cfg.get('save_prediction_only', False)
  179. # pass clsid2catid info to metric instance to avoid multiple loading
  180. # annotation file
  181. clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
  182. if self.mode == 'eval' else None
  183. # when do validation in train, annotation file should be get from
  184. # EvalReader instead of self.dataset(which is TrainReader)
  185. anno_file = self.dataset.get_anno()
  186. dataset = self.dataset
  187. if self.mode == 'train' and validate:
  188. eval_dataset = self.cfg['EvalDataset']
  189. eval_dataset.check_or_download_dataset()
  190. anno_file = eval_dataset.get_anno()
  191. dataset = eval_dataset
  192. IouType = self.cfg['IouType'] if 'IouType' in self.cfg else 'bbox'
  193. if self.cfg.metric == "COCO":
  194. self._metrics = [
  195. COCOMetric(
  196. anno_file=anno_file,
  197. clsid2catid=clsid2catid,
  198. classwise=classwise,
  199. output_eval=output_eval,
  200. bias=bias,
  201. IouType=IouType,
  202. save_prediction_only=save_prediction_only)
  203. ]
  204. elif self.cfg.metric == "SNIPERCOCO": # sniper
  205. self._metrics = [
  206. SNIPERCOCOMetric(
  207. anno_file=anno_file,
  208. dataset=dataset,
  209. clsid2catid=clsid2catid,
  210. classwise=classwise,
  211. output_eval=output_eval,
  212. bias=bias,
  213. IouType=IouType,
  214. save_prediction_only=save_prediction_only)
  215. ]
  216. elif self.cfg.metric == 'RBOX':
  217. # TODO: bias should be unified
  218. bias = self.cfg['bias'] if 'bias' in self.cfg else 0
  219. output_eval = self.cfg['output_eval'] \
  220. if 'output_eval' in self.cfg else None
  221. save_prediction_only = self.cfg.get('save_prediction_only', False)
  222. # pass clsid2catid info to metric instance to avoid multiple loading
  223. # annotation file
  224. clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
  225. if self.mode == 'eval' else None
  226. # when do validation in train, annotation file should be get from
  227. # EvalReader instead of self.dataset(which is TrainReader)
  228. anno_file = self.dataset.get_anno()
  229. if self.mode == 'train' and validate:
  230. eval_dataset = self.cfg['EvalDataset']
  231. eval_dataset.check_or_download_dataset()
  232. anno_file = eval_dataset.get_anno()
  233. self._metrics = [
  234. RBoxMetric(
  235. anno_file=anno_file,
  236. clsid2catid=clsid2catid,
  237. classwise=classwise,
  238. output_eval=output_eval,
  239. bias=bias,
  240. save_prediction_only=save_prediction_only)
  241. ]
  242. elif self.cfg.metric == 'VOC':
  243. self._metrics = [
  244. VOCMetric(
  245. label_list=self.dataset.get_label_list(),
  246. class_num=self.cfg.num_classes,
  247. map_type=self.cfg.map_type,
  248. classwise=classwise)
  249. ]
  250. elif self.cfg.metric == 'WiderFace':
  251. multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
  252. self._metrics = [
  253. WiderFaceMetric(
  254. image_dir=os.path.join(self.dataset.dataset_dir,
  255. self.dataset.image_dir),
  256. anno_file=self.dataset.get_anno(),
  257. multi_scale=multi_scale)
  258. ]
  259. elif self.cfg.metric == 'KeyPointTopDownCOCOEval':
  260. eval_dataset = self.cfg['EvalDataset']
  261. eval_dataset.check_or_download_dataset()
  262. anno_file = eval_dataset.get_anno()
  263. save_prediction_only = self.cfg.get('save_prediction_only', False)
  264. self._metrics = [
  265. KeyPointTopDownCOCOEval(
  266. anno_file,
  267. len(eval_dataset),
  268. self.cfg.num_joints,
  269. self.cfg.save_dir,
  270. save_prediction_only=save_prediction_only)
  271. ]
  272. elif self.cfg.metric == 'KeyPointTopDownMPIIEval':
  273. eval_dataset = self.cfg['EvalDataset']
  274. eval_dataset.check_or_download_dataset()
  275. anno_file = eval_dataset.get_anno()
  276. save_prediction_only = self.cfg.get('save_prediction_only', False)
  277. self._metrics = [
  278. KeyPointTopDownMPIIEval(
  279. anno_file,
  280. len(eval_dataset),
  281. self.cfg.num_joints,
  282. self.cfg.save_dir,
  283. save_prediction_only=save_prediction_only)
  284. ]
  285. elif self.cfg.metric == 'MOTDet':
  286. self._metrics = [JDEDetMetric(), ]
  287. else:
  288. logger.warning("Metric not support for metric type {}".format(
  289. self.cfg.metric))
  290. self._metrics = []
  291. def _reset_metrics(self):
  292. for metric in self._metrics:
  293. metric.reset()
  294. def register_callbacks(self, callbacks):
  295. callbacks = [c for c in list(callbacks) if c is not None]
  296. for c in callbacks:
  297. assert isinstance(c, Callback), \
  298. "metrics shoule be instances of subclass of Metric"
  299. self._callbacks.extend(callbacks)
  300. self._compose_callback = ComposeCallback(self._callbacks)
  301. def register_metrics(self, metrics):
  302. metrics = [m for m in list(metrics) if m is not None]
  303. for m in metrics:
  304. assert isinstance(m, Metric), \
  305. "metrics shoule be instances of subclass of Metric"
  306. self._metrics.extend(metrics)
  307. def load_weights(self, weights):
  308. if self.is_loaded_weights:
  309. return
  310. self.start_epoch = 0
  311. load_pretrain_weight(self.model, weights)
  312. logger.debug("Load weights {} to start training".format(weights))
  313. def load_weights_sde(self, det_weights, reid_weights):
  314. if self.model.detector:
  315. load_weight(self.model.detector, det_weights)
  316. load_weight(self.model.reid, reid_weights)
  317. else:
  318. load_weight(self.model.reid, reid_weights)
  319. def resume_weights(self, weights):
  320. # support Distill resume weights
  321. if hasattr(self.model, 'student_model'):
  322. self.start_epoch = load_weight(self.model.student_model, weights,
  323. self.optimizer)
  324. else:
  325. self.start_epoch = load_weight(self.model, weights, self.optimizer,
  326. self.ema if self.use_ema else None)
  327. logger.debug("Resume weights of epoch {}".format(self.start_epoch))
  328. def train(self, validate=False):
  329. assert self.mode == 'train', "Model not in 'train' mode"
  330. Init_mark = False
  331. if validate:
  332. self.cfg.EvalDataset = create("EvalDataset")()
  333. sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
  334. self.cfg.use_gpu and self._nranks > 1)
  335. if sync_bn:
  336. self.model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
  337. self.model)
  338. model = self.model
  339. if self.cfg.get('fleet', False):
  340. model = fleet.distributed_model(model)
  341. self.optimizer = fleet.distributed_optimizer(self.optimizer)
  342. elif self._nranks > 1:
  343. find_unused_parameters = self.cfg[
  344. 'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
  345. model = paddle.DataParallel(
  346. self.model, find_unused_parameters=find_unused_parameters)
  347. # enabel auto mixed precision mode
  348. if self.cfg.get('amp', False):
  349. scaler = amp.GradScaler(
  350. enable=self.cfg.use_gpu or self.cfg.use_npu,
  351. init_loss_scaling=1024)
  352. self.status.update({
  353. 'epoch_id': self.start_epoch,
  354. 'step_id': 0,
  355. 'steps_per_epoch': len(self.loader)
  356. })
  357. self.status['batch_time'] = stats.SmoothedValue(
  358. self.cfg.log_iter, fmt='{avg:.4f}')
  359. self.status['data_time'] = stats.SmoothedValue(
  360. self.cfg.log_iter, fmt='{avg:.4f}')
  361. self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)
  362. if self.cfg.get('print_flops', False):
  363. flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
  364. self.dataset, self.cfg.worker_num)
  365. self._flops(flops_loader)
  366. profiler_options = self.cfg.get('profiler_options', None)
  367. self._compose_callback.on_train_begin(self.status)
  368. for epoch_id in range(self.start_epoch, self.cfg.epoch):
  369. self.status['mode'] = 'train'
  370. self.status['epoch_id'] = epoch_id
  371. self._compose_callback.on_epoch_begin(self.status)
  372. self.loader.dataset.set_epoch(epoch_id)
  373. model.train()
  374. iter_tic = time.time()
  375. for step_id, data in enumerate(self.loader):
  376. self.status['data_time'].update(time.time() - iter_tic)
  377. self.status['step_id'] = step_id
  378. profiler.add_profiler_step(profiler_options)
  379. self._compose_callback.on_step_begin(self.status)
  380. data['epoch_id'] = epoch_id
  381. if self.cfg.get('amp', False):
  382. with amp.auto_cast(enable=self.cfg.use_gpu):
  383. # model forward
  384. outputs = model(data)
  385. loss = outputs['loss']
  386. # model backward
  387. scaled_loss = scaler.scale(loss)
  388. scaled_loss.backward()
  389. # in dygraph mode, optimizer.minimize is equal to optimizer.step
  390. scaler.minimize(self.optimizer, scaled_loss)
  391. else:
  392. # model forward
  393. outputs = model(data)
  394. loss = outputs['loss']
  395. # model backward
  396. loss.backward()
  397. self.optimizer.step()
  398. curr_lr = self.optimizer.get_lr()
  399. self.lr.step()
  400. if self.cfg.get('unstructured_prune'):
  401. self.pruner.step()
  402. self.optimizer.clear_grad()
  403. self.status['learning_rate'] = curr_lr
  404. if self._nranks < 2 or self._local_rank == 0:
  405. self.status['training_staus'].update(outputs)
  406. self.status['batch_time'].update(time.time() - iter_tic)
  407. self._compose_callback.on_step_end(self.status)
  408. if self.use_ema:
  409. self.ema.update()
  410. iter_tic = time.time()
  411. if self.cfg.get('unstructured_prune'):
  412. self.pruner.update_params()
  413. is_snapshot = (self._nranks < 2 or self._local_rank == 0) \
  414. and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 or epoch_id == self.end_epoch - 1)
  415. if is_snapshot and self.use_ema:
  416. # apply ema weight on model
  417. weight = copy.deepcopy(self.model.state_dict())
  418. self.model.set_dict(self.ema.apply())
  419. self.status['weight'] = weight
  420. self._compose_callback.on_epoch_end(self.status)
  421. if validate and is_snapshot:
  422. if not hasattr(self, '_eval_loader'):
  423. # build evaluation dataset and loader
  424. self._eval_dataset = self.cfg.EvalDataset
  425. self._eval_batch_sampler = \
  426. paddle.io.BatchSampler(
  427. self._eval_dataset,
  428. batch_size=self.cfg.EvalReader['batch_size'])
  429. # If metric is VOC, need to be set collate_batch=False.
  430. if self.cfg.metric == 'VOC':
  431. self.cfg['EvalReader']['collate_batch'] = False
  432. self._eval_loader = create('EvalReader')(
  433. self._eval_dataset,
  434. self.cfg.worker_num,
  435. batch_sampler=self._eval_batch_sampler)
  436. # if validation in training is enabled, metrics should be re-init
  437. # Init_mark makes sure this code will only execute once
  438. if validate and Init_mark == False:
  439. Init_mark = True
  440. self._init_metrics(validate=validate)
  441. self._reset_metrics()
  442. with paddle.no_grad():
  443. self.status['save_best_model'] = True
  444. self._eval_with_loader(self._eval_loader)
  445. if is_snapshot and self.use_ema:
  446. # reset original weight
  447. self.model.set_dict(weight)
  448. self.status.pop('weight')
  449. self._compose_callback.on_train_end(self.status)
  450. def _eval_with_loader(self, loader):
  451. sample_num = 0
  452. tic = time.time()
  453. self._compose_callback.on_epoch_begin(self.status)
  454. self.status['mode'] = 'eval'
  455. self.model.eval()
  456. if self.cfg.get('print_flops', False):
  457. flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
  458. self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
  459. self._flops(flops_loader)
  460. for step_id, data in enumerate(loader):
  461. self.status['step_id'] = step_id
  462. self._compose_callback.on_step_begin(self.status)
  463. # forward
  464. outs = self.model(data)
  465. # update metrics
  466. for metric in self._metrics:
  467. metric.update(data, outs)
  468. # multi-scale inputs: all inputs have same im_id
  469. if isinstance(data, typing.Sequence):
  470. sample_num += data[0]['im_id'].numpy().shape[0]
  471. else:
  472. sample_num += data['im_id'].numpy().shape[0]
  473. self._compose_callback.on_step_end(self.status)
  474. self.status['sample_num'] = sample_num
  475. self.status['cost_time'] = time.time() - tic
  476. # accumulate metric to log out
  477. for metric in self._metrics:
  478. metric.accumulate()
  479. metric.log()
  480. self._compose_callback.on_epoch_end(self.status)
  481. # reset metric states for metric may performed multiple times
  482. self._reset_metrics()
  483. def evaluate(self):
  484. with paddle.no_grad():
  485. self._eval_with_loader(self.loader)
  486. def predict(self,
  487. images,
  488. draw_threshold=0.5,
  489. output_dir='output',
  490. save_results=False):
  491. self.dataset.set_images(images)
  492. loader = create('TestReader')(self.dataset, 0)
  493. def setup_metrics_for_loader():
  494. # mem
  495. metrics = copy.deepcopy(self._metrics)
  496. mode = self.mode
  497. save_prediction_only = self.cfg[
  498. 'save_prediction_only'] if 'save_prediction_only' in self.cfg else None
  499. output_eval = self.cfg[
  500. 'output_eval'] if 'output_eval' in self.cfg else None
  501. # modify
  502. self.mode = '_test'
  503. self.cfg['save_prediction_only'] = True
  504. self.cfg['output_eval'] = output_dir
  505. self._init_metrics()
  506. # restore
  507. self.mode = mode
  508. self.cfg.pop('save_prediction_only')
  509. if save_prediction_only is not None:
  510. self.cfg['save_prediction_only'] = save_prediction_only
  511. self.cfg.pop('output_eval')
  512. if output_eval is not None:
  513. self.cfg['output_eval'] = output_eval
  514. _metrics = copy.deepcopy(self._metrics)
  515. self._metrics = metrics
  516. return _metrics
  517. if save_results:
  518. metrics = setup_metrics_for_loader()
  519. else:
  520. metrics = []
  521. imid2path = self.dataset.get_imid2path()
  522. anno_file = self.dataset.get_anno()
  523. clsid2catid, catid2name = get_categories(
  524. self.cfg.metric, anno_file=anno_file)
  525. # Run Infer
  526. self.status['mode'] = 'test'
  527. self.model.eval()
  528. if self.cfg.get('print_flops', False):
  529. flops_loader = create('TestReader')(self.dataset, 0)
  530. self._flops(flops_loader)
  531. results = []
  532. for step_id, data in enumerate(tqdm(loader)):
  533. self.status['step_id'] = step_id
  534. # forward
  535. outs = self.model(data)
  536. for _m in metrics:
  537. _m.update(data, outs)
  538. for key in ['im_shape', 'scale_factor', 'im_id']:
  539. if isinstance(data, typing.Sequence):
  540. outs[key] = data[0][key]
  541. else:
  542. outs[key] = data[key]
  543. for key, value in outs.items():
  544. if hasattr(value, 'numpy'):
  545. outs[key] = value.numpy()
  546. results.append(outs)
  547. # sniper
  548. if type(self.dataset) == SniperCOCODataSet:
  549. results = self.dataset.anno_cropper.aggregate_chips_detections(
  550. results)
  551. for _m in metrics:
  552. _m.accumulate()
  553. _m.reset()
  554. for outs in results:
  555. batch_res = get_infer_results(outs, clsid2catid)
  556. bbox_num = outs['bbox_num']
  557. start = 0
  558. for i, im_id in enumerate(outs['im_id']):
  559. image_path = imid2path[int(im_id)]
  560. image = Image.open(image_path).convert('RGB')
  561. image = ImageOps.exif_transpose(image)
  562. self.status['original_image'] = np.array(image.copy())
  563. end = start + bbox_num[i]
  564. bbox_res = batch_res['bbox'][start:end] \
  565. if 'bbox' in batch_res else None
  566. mask_res = batch_res['mask'][start:end] \
  567. if 'mask' in batch_res else None
  568. segm_res = batch_res['segm'][start:end] \
  569. if 'segm' in batch_res else None
  570. keypoint_res = batch_res['keypoint'][start:end] \
  571. if 'keypoint' in batch_res else None
  572. image = visualize_results(
  573. image, bbox_res, mask_res, segm_res, keypoint_res,
  574. int(im_id), catid2name, draw_threshold)
  575. self.status['result_image'] = np.array(image.copy())
  576. if self._compose_callback:
  577. self._compose_callback.on_step_end(self.status)
  578. # save image with detection
  579. save_name = self._get_save_image_name(output_dir, image_path)
  580. logger.info("Detection bbox results save in {}".format(
  581. save_name))
  582. image.save(save_name, quality=95)
  583. start = end
  584. def _get_save_image_name(self, output_dir, image_path):
  585. """
  586. Get save image name from source image path.
  587. """
  588. if not os.path.exists(output_dir):
  589. os.makedirs(output_dir)
  590. image_name = os.path.split(image_path)[-1]
  591. name, ext = os.path.splitext(image_name)
  592. return os.path.join(output_dir, "{}".format(name)) + ext
    def _get_infer_cfg_and_input_spec(self, save_dir, prune_input=True):
        """Prepare the model for export and build its static-graph input spec.

        Side effects: sets ``deploy``, ``fuse_norm``, ``export_post_process``
        and ``export_nms`` on the model when it has those attributes, calls
        ``convert_to_deploy()`` on every sublayer that defines it, and writes
        ``infer_cfg.yml`` into ``save_dir``.

        Args:
            save_dir (str): directory that receives ``infer_cfg.yml``.
            prune_input (bool): if True, convert the model with
                ``paddle.jit.to_static`` and prune the input spec with
                ``_prune_input_spec``. Otherwise return the unpruned spec and
                no static model.

        Returns:
            tuple: ``(static_model, pruned_input_spec)``; ``static_model`` is
            None when ``prune_input`` is False.
        """
        image_shape = None
        im_shape = [None, 2]
        scale_factor = [None, 2]
        # MOT architectures keep their reader config under 'TestMOTReader'.
        if self.cfg.architecture in MOT_ARCH:
            test_reader_name = 'TestMOTReader'
        else:
            test_reader_name = 'TestReader'
        if 'inputs_def' in self.cfg[test_reader_name]:
            inputs_def = self.cfg[test_reader_name]['inputs_def']
            image_shape = inputs_def.get('image_shape', None)
        # set image_shape=[None, 3, -1, -1] as default
        if image_shape is None:
            image_shape = [None, 3, -1, -1]

        # A 3-element shape lacks the batch dim, so a dynamic (None) batch is
        # prepended. Otherwise the configured batch size is reused for the
        # im_shape / scale_factor inputs.
        if len(image_shape) == 3:
            image_shape = [None] + image_shape
        else:
            im_shape = [image_shape[0], 2]
            scale_factor = [image_shape[0], 2]

        if hasattr(self.model, 'deploy'):
            self.model.deploy = True

        for layer in self.model.sublayers():
            if hasattr(layer, 'convert_to_deploy'):
                layer.convert_to_deploy()

        # NOTE(review): post_process and nms default to True when cfg has no
        # 'export' section, but to False when the section exists without
        # those keys; confirm this asymmetry is intended.
        export_post_process = self.cfg['export'].get(
            'post_process', False) if hasattr(self.cfg, 'export') else True
        export_nms = self.cfg['export'].get('nms', False) if hasattr(
            self.cfg, 'export') else True
        export_benchmark = self.cfg['export'].get(
            'benchmark', False) if hasattr(self.cfg, 'export') else False
        if hasattr(self.model, 'fuse_norm'):
            self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize',
                                                              False)
        # Benchmark export disables both post-processing and NMS.
        if hasattr(self.model, 'export_post_process'):
            self.model.export_post_process = export_post_process if not export_benchmark else False
        if hasattr(self.model, 'export_nms'):
            self.model.export_nms = export_nms if not export_benchmark else False
        # With post-processing exported, the batch dimension is made dynamic.
        if export_post_process and not export_benchmark:
            image_shape = [None] + image_shape[1:]

        # Save infer cfg
        _dump_infer_config(self.cfg,
                           os.path.join(save_dir, 'infer_cfg.yml'), image_shape,
                           self.model)

        input_spec = [{
            "image": InputSpec(
                shape=image_shape, name='image'),
            "im_shape": InputSpec(
                shape=im_shape, name='im_shape'),
            "scale_factor": InputSpec(
                shape=scale_factor, name='scale_factor')
        }]
        # DeepSORT additionally takes fixed-size image crops as input.
        if self.cfg.architecture == 'DeepSORT':
            input_spec[0].update({
                "crops": InputSpec(
                    shape=[None, 3, 192, 64], name='crops')
            })
        if prune_input:
            static_model = paddle.jit.to_static(
                self.model, input_spec=input_spec)
            # NOTE: dy2st do not pruned program, but jit.save will prune program
            # input spec, prune input spec here and save with pruned input spec
            pruned_input_spec = _prune_input_spec(
                input_spec, static_model.forward.main_program,
                static_model.forward.outputs)
        else:
            static_model = None
            pruned_input_spec = input_spec

        # TODO: Hard code, delete it when support prune input_spec.
        if self.cfg.architecture == 'PicoDet' and not export_post_process:
            pruned_input_spec = [{
                "image": InputSpec(
                    shape=image_shape, name='image')
            }]

        return static_model, pruned_input_spec
  667. def export(self, output_dir='output_inference'):
  668. self.model.eval()
  669. model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
  670. save_dir = os.path.join(output_dir, model_name)
  671. if not os.path.exists(save_dir):
  672. os.makedirs(save_dir)
  673. static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec(
  674. save_dir)
  675. # dy2st and save model
  676. if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':
  677. paddle.jit.save(
  678. static_model,
  679. os.path.join(save_dir, 'model'),
  680. input_spec=pruned_input_spec)
  681. else:
  682. self.cfg.slim.save_quantized_model(
  683. self.model,
  684. os.path.join(save_dir, 'model'),
  685. input_spec=pruned_input_spec)
  686. logger.info("Export model and saved in {}".format(save_dir))
  687. def post_quant(self, output_dir='output_inference'):
  688. model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
  689. save_dir = os.path.join(output_dir, model_name)
  690. if not os.path.exists(save_dir):
  691. os.makedirs(save_dir)
  692. for idx, data in enumerate(self.loader):
  693. self.model(data)
  694. if idx == int(self.cfg.get('quant_batch_num', 10)):
  695. break
  696. # TODO: support prune input_spec
  697. _, pruned_input_spec = self._get_infer_cfg_and_input_spec(
  698. save_dir, prune_input=False)
  699. self.cfg.slim.save_quantized_model(
  700. self.model,
  701. os.path.join(save_dir, 'model'),
  702. input_spec=pruned_input_spec)
  703. logger.info("Export Post-Quant model and saved in {}".format(save_dir))
  704. def _flops(self, loader):
  705. self.model.eval()
  706. try:
  707. import paddleslim
  708. except Exception as e:
  709. logger.warning(
  710. 'Unable to calculate flops, please install paddleslim, for example: `pip install paddleslim`'
  711. )
  712. return
  713. from paddleslim.analysis import dygraph_flops as flops
  714. input_data = None
  715. for data in loader:
  716. input_data = data
  717. break
  718. input_spec = [{
  719. "image": input_data['image'][0].unsqueeze(0),
  720. "im_shape": input_data['im_shape'][0].unsqueeze(0),
  721. "scale_factor": input_data['scale_factor'][0].unsqueeze(0)
  722. }]
  723. flops = flops(self.model, input_spec) / (1000**3)
  724. logger.info(" Model FLOPs : {:.6f}G. (image shape is {})".format(
  725. flops, input_data['image'][0].unsqueeze(0).shape))
  726. def parse_mot_images(self, cfg):
  727. import glob
  728. # for quant
  729. dataset_dir = cfg['EvalMOTDataset'].dataset_dir
  730. data_root = cfg['EvalMOTDataset'].data_root
  731. data_root = '{}/{}'.format(dataset_dir, data_root)
  732. seqs = os.listdir(data_root)
  733. seqs.sort()
  734. all_images = []
  735. for seq in seqs:
  736. infer_dir = os.path.join(data_root, seq)
  737. assert infer_dir is None or os.path.isdir(infer_dir), \
  738. "{} is not a directory".format(infer_dir)
  739. images = set()
  740. exts = ['jpg', 'jpeg', 'png', 'bmp']
  741. exts += [ext.upper() for ext in exts]
  742. for ext in exts:
  743. images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
  744. images = list(images)
  745. images.sort()
  746. assert len(images) > 0, "no image found in {}".format(infer_dir)
  747. all_images.extend(images)
  748. logger.info("Found {} inference images in total.".format(
  749. len(images)))
  750. return all_images