test_architectures.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import unittest
  18. import paddle
  19. import os
  20. import sys
  21. # add python path of PadleDetection to sys.path
  22. parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 4)))
  23. if parent_path not in sys.path:
  24. sys.path.append(parent_path)
  25. try:
  26. from ppdet.utils.check import enable_static_mode, logger
  27. from ppdet.modeling.tests.decorator_helper import prog_scope
  28. from ppdet.core.workspace import load_config, merge_config, create
  29. except ImportError as e:
  30. if sys.argv[0].find('static') >= 0:
  31. logger.error("Importing ppdet failed when running static model "
  32. "with error: {}\n"
  33. "please try:\n"
  34. "\t1. run static model under PaddleDetection/static "
  35. "directory\n"
  36. "\t2. run 'pip uninstall ppdet' to uninstall ppdet "
  37. "dynamic version firstly.".format(e))
  38. sys.exit(-1)
  39. else:
  40. raise e
  41. class TestFasterRCNN(unittest.TestCase):
  42. def setUp(self):
  43. self.set_config()
  44. self.cfg = load_config(self.cfg_file)
  45. self.detector_type = self.cfg['architecture']
  46. def set_config(self):
  47. self.cfg_file = 'configs/faster_rcnn_r50_1x.yml'
  48. @prog_scope()
  49. def test_train(self):
  50. model = create(self.detector_type)
  51. inputs_def = self.cfg['TrainReader']['inputs_def']
  52. inputs_def['image_shape'] = [3, None, None]
  53. feed_vars, _ = model.build_inputs(**inputs_def)
  54. train_fetches = model.train(feed_vars)
  55. @prog_scope()
  56. def test_test(self):
  57. inputs_def = self.cfg['EvalReader']['inputs_def']
  58. inputs_def['image_shape'] = [3, None, None]
  59. model = create(self.detector_type)
  60. feed_vars, _ = model.build_inputs(**inputs_def)
  61. test_fetches = model.eval(feed_vars)
  62. class TestMaskRCNN(TestFasterRCNN):
  63. def set_config(self):
  64. self.cfg_file = 'configs/mask_rcnn_r50_1x.yml'
  65. @unittest.skip(
  66. reason="It should be fixed to adapt https://github.com/PaddlePaddle/Paddle/pull/23797"
  67. )
  68. class TestCascadeRCNN(TestFasterRCNN):
  69. def set_config(self):
  70. self.cfg_file = 'configs/cascade_rcnn_r50_fpn_1x.yml'
  71. @unittest.skipIf(
  72. paddle.version.full_version < "1.8.4",
  73. "Paddle 2.0 should be used for YOLOv3 takes scale_x_y as inputs, "
  74. "disable this unittest for Paddle major version < 2")
  75. class TestYolov3(TestFasterRCNN):
  76. def set_config(self):
  77. self.cfg_file = 'configs/yolov3_darknet.yml'
  78. class TestRetinaNet(TestFasterRCNN):
  79. def set_config(self):
  80. self.cfg_file = 'configs/retinanet_r50_fpn_1x.yml'
  81. class TestSSD(TestFasterRCNN):
  82. def set_config(self):
  83. self.cfg_file = 'configs/ssd/ssd_mobilenet_v1_voc.yml'
  84. if __name__ == '__main__':
  85. enable_static_mode()
  86. unittest.main()