frame_va_test.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import os, sys
  2. sys.path.append(os.path.split(os.path.realpath(__file__))[0] + "/../lib")
  3. from cnstream import *
  4. from cnstream_cpptest import *
  5. # import cv2
  6. def assert_eq(actual_val, expect_val):
  7. assert actual_val == expect_val, "Actual value is " + str(actual_val) + ". Expect " + str(expect_val)
  8. class TestFrameVa:
  9. def test_dataframe(self):
  10. frame = CNFrameInfo("stream_id_0")
  11. frame_id = 111
  12. width = 1280
  13. height = 720
  14. stride = [1282, 1284]
  15. fmt = CNDataFormat.CN_PIXEL_FORMAT_YUV420_NV12
  16. dev_type = DevType.MLU
  17. dev_id = 0
  18. dst_device_id = 1
  19. src_data_frame = CNDataFrame()
  20. src_data_frame.frame_id = frame_id
  21. src_data_frame.width = width
  22. src_data_frame.height = height
  23. src_data_frame.stride = stride
  24. src_data_frame.fmt = fmt
  25. src_data_frame.ctx.dev_type = dev_type
  26. src_data_frame.ctx.dev_id = dev_id
  27. src_data_frame.dst_device_id = dst_device_id
  28. set_data_frame(frame, src_data_frame)
  29. data_frame = frame.get_cn_data_frame()
  30. # data
  31. src_data_y_size = stride[0] * height
  32. src_data_uv_size = stride[1] * height / 2
  33. assert_eq(data_frame.data(0).get_size(), src_data_y_size)
  34. assert_eq(data_frame.data(1).get_size(), src_data_uv_size)
  35. assert_eq(data_frame.frame_id, frame_id)
  36. assert_eq(data_frame.fmt, fmt)
  37. assert_eq(data_frame.width, width)
  38. assert_eq(data_frame.height, height)
  39. # stride
  40. assert_eq(data_frame.stride[0], stride[0])
  41. assert_eq(data_frame.stride[1], stride[1])
  42. modify_stride = [1286, 1288]
  43. data_frame.stride = modify_stride
  44. data_frame_tmp = frame.get_cn_data_frame()
  45. assert_eq(data_frame_tmp.stride[0], modify_stride[0])
  46. assert_eq(data_frame_tmp.stride[1], modify_stride[1])
  47. data_frame.stride = stride
  48. assert_eq(data_frame.ctx.dev_type, dev_type)
  49. assert_eq(data_frame.ctx.dev_id, dev_id)
  50. assert_eq(data_frame.dst_device_id, dst_device_id)
  51. modify_dst_dev_id = 4
  52. data_frame.dst_device_id = modify_dst_dev_id
  53. data_frame_tmp = frame.get_cn_data_frame()
  54. assert_eq(data_frame_tmp.dst_device_id, modify_dst_dev_id)
  55. # functions
  56. assert_eq(data_frame.get_planes(), 2)
  57. assert_eq(data_frame.get_plane_bytes(0), src_data_y_size)
  58. assert_eq(data_frame.get_plane_bytes(1), src_data_uv_size)
  59. assert_eq(data_frame.get_bytes(), src_data_y_size + src_data_uv_size)
  60. assert_eq(data_frame.has_bgr_image(), False)
  61. # cv Mat
  62. img = data_frame.image_bgr()
  63. assert_eq(img.shape, (height, width, 3))
  64. assert_eq(data_frame.has_bgr_image(), True)
  65. # if modify the numpy array, the image_bgr of the data_frame will be modified
  66. # cv2.imwrite("./test_img.jpg", img)
  67. for i in range(300):
  68. for j in range(200):
  69. img[i, j, 0] = 0
  70. img[i, j, 1] = 0
  71. img[i, j, 2] = 0
  72. img_res = data_frame.image_bgr()
  73. assert img.all() == img_res.all()
  74. # cv2.imwrite("./test_img_res.jpg", img_res)
  75. def test_cninfer_objects(self):
  76. frame = CNFrameInfo("stream_id_0")
  77. set_infer_objs(frame, CNInferObjs())
  78. objs_holder = frame.get_cn_infer_objects()
  79. # no object is in objs_holder
  80. assert_eq(len(objs_holder.objs), 0)
  81. # Add an object to objs_holder
  82. class_id = "1"
  83. track_id = "2"
  84. score = 0.5
  85. bbox = CNInferBoundingBox(0.1, 0.2, 0.5, 0.6)
  86. user_data = "hi cnstream"
  87. obj = CNInferObject()
  88. obj.id = class_id
  89. obj.track_id = track_id
  90. obj.score = score
  91. obj.bbox = bbox
  92. py_collection = obj.get_py_collection()
  93. py_collection["user_data"] = user_data
  94. objs_holder.push_back(obj)
  95. # Check the object be added
  96. assert_eq(len(objs_holder.objs), 1)
  97. assert_eq(objs_holder.objs[0].id, class_id)
  98. assert_eq(objs_holder.objs[0].track_id, track_id)
  99. assert_eq(objs_holder.objs[0].score, score)
  100. assert_eq(objs_holder.objs[0].bbox.x, bbox.x)
  101. assert_eq(objs_holder.objs[0].bbox.y, bbox.y)
  102. assert_eq(objs_holder.objs[0].bbox.w, bbox.w)
  103. assert_eq(objs_holder.objs[0].bbox.h, bbox.h)
  104. assert_eq(len(objs_holder.objs[0].get_py_collection()), 1)
  105. assert "user_data" in objs_holder.objs[0].get_py_collection()
  106. assert_eq(objs_holder.objs[0].get_py_collection()["user_data"], user_data)
  107. # Add an attr to the object
  108. attr0_id = 0
  109. attr0_value = 5
  110. attr0_score = 0.8
  111. attr0 = CNInferAttr(attr0_id, attr0_value, attr0_score)
  112. objs_holder.objs[0].add_attribute("attr0", attr0)
  113. attr1 = CNInferAttr(attr0_id, attr0_value, attr0_score)
  114. attr1.id = 1
  115. attr1.value = 4
  116. attr1.score = 0.6
  117. objs_holder.objs[0].add_attribute("attr1", attr1)
  118. assert_eq(objs_holder.objs[0].get_attribute("attr0").id, attr0_id)
  119. assert_eq(objs_holder.objs[0].get_attribute("attr0").value, attr0_value)
  120. assert abs(objs_holder.objs[0].get_attribute("attr0").score - attr0_score) < 0.000001
  121. assert_eq(objs_holder.objs[0].get_attribute("attr1").id, attr1.id)
  122. assert_eq(objs_holder.objs[0].get_attribute("attr1").value, attr1.value)
  123. assert_eq(objs_holder.objs[0].get_attribute("attr1").score, attr1.score)
  124. # Add an extra attr to the object
  125. extra_attr_val = "extra_attribute"
  126. objs_holder.objs[0].add_extra_attribute("extra0", extra_attr_val)
  127. assert_eq(objs_holder.objs[0].get_extra_attribute("extra0"), extra_attr_val)
  128. # Add a feature to the object
  129. feature0 = [0.15, 0.22, 0.37]
  130. objs_holder.objs[0].add_feature("feat0", feature0)
  131. feature1 = [1, 2, 3, 4]
  132. objs_holder.objs[0].add_feature("feat1", feature1)
  133. diff = [abs(objs_holder.objs[0].get_feature("feat0")[i] - feature0[i]) for i in range(3)]
  134. assert diff < [0.000001, 0.000001, 0.000001], diff
  135. assert_eq(objs_holder.objs[0].get_feature("feat1"), feature1)
  136. assert_eq(len(objs_holder.objs[0].get_features()), 2)