postproc_test.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import os, sys
  2. sys.path.append(os.path.split(os.path.realpath(__file__))[0] + "/../lib")
  3. from cnstream import *
  4. from cnstream_cpptest import *
  5. # in order to determine whether the python function is called by cpp
  6. init_called = False
  7. execute_called = False
  8. postproc_params = None
  9. obj_init_called = False
  10. obj_execute_called = False
  11. obj_postproc_params = None
  12. received_input_shapes = None
  13. received_output_shape = None
  14. class CustomPostproc(Postproc):
  15. def __init__(self):
  16. Postproc.__init__(self)
  17. def init(self, params):
  18. global init_called
  19. global postproc_params
  20. init_called = True
  21. postproc_params = params
  22. return True
  23. def execute(self, net_outputs, input_shapes, finfo):
  24. global execute_called
  25. execute_called = True
  26. global received_input_shapes
  27. global received_output_shape
  28. received_input_shapes = input_shapes
  29. received_output_shape = net_outputs[0].shape
  30. class CustomObjPostproc(ObjPostproc):
  31. def __init__(self):
  32. ObjPostproc.__init__(self)
  33. def init(self, params):
  34. global obj_init_called
  35. global obj_postproc_params
  36. obj_init_called = True
  37. obj_postproc_params = params
  38. return True
  39. def execute(self, net_outputs, input_shapes, finfo, obj):
  40. global obj_execute_called
  41. obj_execute_called = True
  42. global received_input_shapes
  43. global received_output_shape
  44. received_input_shapes = input_shapes
  45. received_output_shape = net_outputs[0].shape
  46. class TestPostproc:
  47. def test_postproc(self):
  48. params = {'pyclass_name' : 'test.postproc_test.CustomPostproc', 'param' : 'value'}
  49. assert cpptest_pypostproc(params)
  50. # test cpp call python init function success
  51. assert init_called
  52. # test custom parameters from cpp pass to python success
  53. assert postproc_params['param'] == 'value'
  54. # test cpp call python execute function success
  55. assert execute_called
  56. # check I/O shapes
  57. expected_input_shapes = [[4, 160, 40, 4]]
  58. expected_output_shape = (20, 1, 84)
  59. assert expected_input_shapes == received_input_shapes
  60. assert expected_output_shape == received_output_shape
  61. def test_obj_postproc(self):
  62. params = {'pyclass_name' : 'test.postproc_test.CustomObjPostproc', 'param' : 'value'}
  63. assert cpptest_pyobjpostproc(params)
  64. # test cpp call python init function success
  65. assert obj_init_called
  66. # test custom parameters from cpp pass to python success
  67. assert obj_postproc_params['param'] == 'value'
  68. # test cpp call python execute function success
  69. assert obj_execute_called
  70. # check I/O shapes
  71. expected_input_shapes = [[4, 160, 40, 4]]
  72. expected_output_shape = (20, 1, 84)
  73. assert expected_input_shapes == received_input_shapes
  74. assert expected_output_shape == received_output_shape