
first commit

zbc committed 3 weeks ago
commit f9c7009403
100 changed files with 11943 additions and 0 deletions
  1. +8 -0  code/data_manage/.idea/.gitignore
  2. +1 -0  code/data_manage/.idea/.name
  3. +15 -0  code/data_manage/.idea/data_manage.iml
  4. +14 -0  code/data_manage/.idea/deployment.xml
  5. +7 -0  code/data_manage/.idea/encodings.xml
  6. +7 -0  code/data_manage/.idea/inspectionProfiles/profiles_settings.xml
  7. +10 -0  code/data_manage/.idea/misc.xml
  8. +8 -0  code/data_manage/.idea/modules.xml
  9. +310 -0  code/data_manage/README.md
  10. BIN  code/data_manage/README.pdf
  11. +24 -0  code/data_manage/config/baidupan.conf
  12. +105 -0  code/data_manage/config/data_manage.conf
  13. +72 -0  code/data_manage/config/utils_db.conf
  14. +89 -0  code/data_manage/gen_data/check_data_step2.py
  15. +142 -0  code/data_manage/gen_data/classFeatureImages_step3.py
  16. +67 -0  code/data_manage/gen_data/class_operate.py
  17. +43 -0  code/data_manage/gen_data/genAnn_step5.py
  18. +70 -0  code/data_manage/gen_data/gen_class_index_step4.py
  19. +145 -0  code/data_manage/gen_data/gen_pb_tfrecord_step7.py
  20. +65 -0  code/data_manage/gen_data/other2jpg.py
  21. +240 -0  code/data_manage/gen_data/splitTrainVal_step6.py
  22. +75 -0  code/data_manage/gen_data/utils/coorUtil.py
  23. +99 -0  code/data_manage/gen_data/utils/fileUtil.py
  24. +32 -0  code/data_manage/gen_data/utils/pathUtil.py
  25. +353 -0  code/data_manage/gen_data/utils/profileUtil.py
  26. +39 -0  code/data_manage/gen_data/utils/strUtil.py
  27. +151 -0  code/data_manage/gen_data/video2image_step1.py
  28. +16 -0  code/data_manage/main.py
  29. +57 -0  code/data_manage/run_gen.py
  30. +177 -0  code/data_manage/test_util/Qt5/checkfile/Form.py
  31. +44 -0  code/data_manage/test_util/Qt5/checkfile/Form.spec
  32. +172 -0  code/data_manage/test_util/Qt5/checkfile/Form.ui
  33. +1 -0  code/data_manage/test_util/Qt5/checkfile/error_log.txt
  34. +8 -0  code/data_manage/test_util/Qt5/checkfile/utils/__init__.py
  35. +87 -0  code/data_manage/test_util/Qt5/checkfile/utils/check_data_step2.py
  36. +118 -0  code/data_manage/test_util/Qt5/checkfile/utils/classFeatureImages_step3.py
  37. +75 -0  code/data_manage/test_util/Qt5/checkfile/utils/coorUtil.py
  38. +17 -0  code/data_manage/test_util/Qt5/checkfile/utils/fileUtil.py
  39. +31 -0  code/data_manage/test_util/Qt5/checkfile/utils/pathUtil.py
  40. +53 -0  code/data_manage/test_util/Qt5/checkfile/utils/profileUtil.py
  41. +3 -0  code/data_manage/test_util/Qt5/error_log.txt
  42. +151 -0  code/data_manage/test_util/Qt5/excel/ModifyTree.py
  43. +8 -0  code/data_manage/test_util/Qt5/excel/__init__.py
  44. +5 -0  code/data_manage/test_util/Qt5/excel/data/excel.conf
  45. BIN  code/data_manage/test_util/Qt5/excel/data/image.xlsx
  46. +252 -0  code/data_manage/test_util/Qt5/frame.py
  47. +167 -0  code/data_manage/test_util/Qt5/frame.ui
  48. +524 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/frame.py
  49. +332 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/frame.ui
  50. +8 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/__init__.py
  51. +93 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/check_data_step2.py
  52. +154 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/classFeatureImages_step3.py
  53. +39 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/genAnn_step5.py
  54. +58 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/gen_class_index_step4.py
  55. +138 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/gen_pb_tfrecord_step7.py
  56. +232 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/splitTrainVal_step6.py
  57. +21 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/test.py
  58. +8 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/utils/__init__.py
  59. +75 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/utils/coorUtil.py
  60. +97 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/utils/fileUtil.py
  61. +32 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/utils/pathUtil.py
  62. +245 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/utils/profileUtil.py
  63. +80 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/utils/select_data_util.py
  64. +39 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/utils/strUtil.py
  65. +150 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/utils/write_xml_jpg_2_all_data_util.py
  66. +251 -0  code/data_manage/test_util/Qt5/gen_tfrecord_ui/run_linux.py
  67. +515 -0  code/data_manage/test_util/Qt5/mysql_find.py
  68. +8 -0  code/data_manage/test_util/Qt5/videos2imgs/__init__.py
  69. +519 -0  code/data_manage/test_util/Qt5/videos2imgs/form1.3.py
  70. +44 -0  code/data_manage/test_util/Qt5/videos2imgs/form1.3.spec
  71. +28 -0  code/data_manage/test_util/filter_file.py
  72. +8 -0  code/data_manage/test_util/find_diff/__init__.py
  73. +111 -0  code/data_manage/test_util/find_diff/find_different.py
  74. +8 -0  code/data_manage/test_util/img_mask/__init__.py
  75. +257 -0  code/data_manage/test_util/img_mask/imgmask.py
  76. +33 -0  code/data_manage/test_util/img_mask/pathUtil.py
  77. +29 -0  code/data_manage/test_util/line2point/draw_line.py
  78. +270 -0  code/data_manage/test_util/line2point/get_convert_belt_info(2).py
  79. +332 -0  code/data_manage/test_util/line2point/hough_line.py
  80. +279 -0  code/data_manage/test_util/line2point/line2point.py
  81. +260 -0  code/data_manage/test_util/line2point/linePointAngleUtil.py
  82. +48 -0  code/data_manage/test_util/merger_tfrecord/merger_tfrecord.py
  83. +65 -0  code/data_manage/test_util/other2jpg/other2jpg.py
  84. +39 -0  code/data_manage/test_util/replace_color/parse_xml.py
  85. +70 -0  code/data_manage/test_util/replace_color/replace_color.py
  86. +51 -0  code/object_detection.sh
  87. +301 -0  code/yolov5/README.md
  88. +260 -0  code/yolov5/detect.py
  89. +241 -0  code/yolov5/detect1.py
  90. +596 -0  code/yolov5/export.py
  91. +3 -0  code/yolov5/export.sh
  92. +145 -0  code/yolov5/hubconf.py
  93. +0 -0  code/yolov5/models/__init__.py
  94. +842 -0  code/yolov5/models/common.py
  95. +122 -0  code/yolov5/models/experimental.py
  96. +59 -0  code/yolov5/models/hub/anchors.yaml
  97. +51 -0  code/yolov5/models/hub/yolov3-spp.yaml
  98. +41 -0  code/yolov5/models/hub/yolov3-tiny.yaml
  99. +51 -0  code/yolov5/models/hub/yolov3.yaml
  100. +48 -0  code/yolov5/models/hub/yolov5-bifpn.yaml

+ 8 - 0
code/data_manage/.idea/.gitignore

@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml

+ 1 - 0
code/data_manage/.idea/.name

@@ -0,0 +1 @@
+run_gen.py

+ 15 - 0
code/data_manage/.idea/data_manage.iml

@@ -0,0 +1,15 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$">
+      <sourceFolder url="file://$MODULE_DIR$/models-master" isTestSource="false" />
+      <sourceFolder url="file://$MODULE_DIR$/models-master/research/object_detection" isTestSource="false" />
+      <excludeFolder url="file://$MODULE_DIR$/venv" />
+    </content>
+    <orderEntry type="jdk" jdkName="Python 3.9 (python3.9)" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+  <component name="PackageRequirementsSettings">
+    <option name="requirementsPath" value="" />
+  </component>
+</module>

+ 14 - 0
code/data_manage/.idea/deployment.xml

@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="PublishConfigData">
+    <serverData>
+      <paths name="fengyang@192.168.20.249:22 password">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+    </serverData>
+  </component>
+</project>

+ 7 - 0
code/data_manage/.idea/encodings.xml

@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="Encoding">
+    <file url="file://$PROJECT_DIR$/gen_data/classFeatureImages_step3.py" charset="UTF-8" />
+    <file url="PROJECT" charset="UTF-8" />
+  </component>
+</project>

+ 7 - 0
code/data_manage/.idea/inspectionProfiles/profiles_settings.xml

@@ -0,0 +1,7 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="PROJECT_PROFILE" value="Default" />
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>

+ 10 - 0
code/data_manage/.idea/misc.xml

@@ -0,0 +1,10 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="JavaScriptSettings">
+    <option name="languageLevel" value="ES6" />
+  </component>
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (python3.9)" project-jdk-type="Python SDK" />
+  <component name="PyPackaging">
+    <option name="earlyReleasesAsUpgrades" value="true" />
+  </component>
+</project>

+ 8 - 0
code/data_manage/.idea/modules.xml

@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/data_manage.iml" filepath="$PROJECT_DIR$/.idea/data_manage.iml" />
+    </modules>
+  </component>
+</project>

+ 310 - 0
code/data_manage/README.md

@@ -0,0 +1,310 @@
+### data_manage Project Documentation
+
+#### Project directory tree:
+
+data_manage
+│  cookies.pickle    cookies for the Baidu Cloud account
+│  run_gen.py    main entry point for running data_manage
+│  run_pan.py    entry point for the Baidu Cloud automation program (the share-link creation feature is unreliable)
+│  
+├─BaiDuPan     Baidu Cloud automation and automatic posting of Baidu Cloud share links to DingTalk
+│  │  baiDuPan.py
+│  │  bypy_operate.py
+│  │  logid.py
+│  │  login_pan.py simulated Baidu Cloud login (userName is the Baidu Cloud account, password is its password)
+│  │  
+│  └─utils    utility classes
+│      │  fileUtil.py
+│      │  pathUtil.py
+│      │  profileUtil.py 
+│          
+├─config    configuration files
+│      baidupan.conf    Baidu Cloud settings
+│      data_manage.conf    settings for video-to-image splitting and data preprocessing (step1-step8)
+│      utils_db.conf    database settings (step8)
+│      
+├─database    database creation and data operations
+│  │  create_db.py    create the database
+│  │  delete_tables_data.py    delete from tables
+│  │  insert_tables_data.py    insert into tables
+│  │  update_tables_data.py    update tables
+│  │  
+│  └─utils utility classes
+│      │  connect_db_util.py
+│      │  delete_data_util.py
+│      │  insert_data_util.py
+│      │  mkdir.py
+│      │  pathUtil.py
+│      │  profileUtil.py
+│      │  select_data_util.py
+│      │  update_data_util.py
+│      │  
+│      └─utils_sql_create_DBImg    config folder holding the SQL statements that create the database (see the table definitions in section 2.3)
+│               utils_sql_class_state.conf    creates the class_state table
+│               utils_sql_image_basic.conf    creates the image_basic table
+│               utils_sql_img_bbox.conf    creates the img_bbox table
+│               utils_sql_scene_basic.conf    creates the scene_basic table
+│               utils_sql_spot_basic.conf    creates the spot_basic table
+│          
+├─gen_data data preprocessing
+│  │  check_data_step2.py    second-pass QC on the annotation vendor's files, ensuring xml and jpg files match one-to-one; step2_check for unlabeled images
+│  │  classFeatureImages_step3.py    extract the annotated feature patches from the images and sort them by class; step3_crop the annotated feature patches and classify them
+│  │  class_operate.py    small tool: delete unwanted classes from local xml files, or rename an existing class
+│  │  genAnn_step5.py    generate per-image details (file path, feature-patch coordinates and other key info); step5_generate the image-info files
+│  │  gen_class_index_step4.py    generate the full class list for the images; step4_generate the class-info files
+│  │  gen_pb_tfrecord_step7.py    generate the pbtxt and TFRecord files; step7_generate the target files
+│  │  splitTrainVal_step6.py    split images into training and validation sets after classifying them by the class-info file; step6_train/val split
+│  │  video2image_step1.py    convert videos to images; step1_split videos into frames (this step also ships as an interactive exe built with PyQt; see test_util/QT5)
+│  │  
+│  └─utils    utility classes
+│          coorUtil.py    check whether a bounding box exceeds the width or height of its background image
+│          fileUtil.py    file filtering helpers
+│          pathUtil.py    path helpers
+│          profileUtil.py    config-file helpers
+│          strUtil.py    string helpers
+│          
+└─test_util
+    │  filter_file.py    take the first half of the video files in every subdirectory (the Luohe power-plant data mixes infrared and normal footage; only the normal footage is kept; this file can be ignored)
+    │  
+    ├─find_diff    batch-generate identical xml files for a special scenario (special case, can be ignored)
+    │      find_different.py
+    │      
+    ├─img_mask    paste old feature patches as foregrounds onto old images as backgrounds to synthesize new training data
+    │      imgmask.py
+    │      pathUtil.py
+    │      
+    ├─merger_tfrecord    merge two tfrecord files into one
+    │      merger_tfrecord.py
+    │  
+    ├─other2jpg    convert other image formats to jpg (36 formats supported)
+    │      other2jpg.py
+    │ 
+    ├─Qt5 can be built into exe programs
+    │  ├─excel organize the master annotation data into a spreadsheet; the program lets you click-select a subset and export it as an annotation spec document for a project, to hand to the annotation vendor as labeling guidance
+    │  │  │  ModifyTree.py main program
+    │  │  │  
+    │  │  └─data
+    │  │      │  image.xlsx excel sheet holding the annotation data
+    │  │      │  excel.conf program config file
+    │  │      │ 
+    │  │      └─images  example images for the annotation data
+    │  │  
+    │  └─videos2imgs split videos into images
+    │         form1.3.py
+    │        
+    └─replace_color merge colors in semantic-segmentation maps: collapse the many color blocks of a multi-class map into one (e.g. keep only trees and merge every other class into background)
+            parse_xml.py
+            replace_color.py
+
+
+
+####  1. Running run_pan.py:
+
+##### 1.1 Refreshing cookies (config file: baidupan.conf)
+
+Run this feature to refresh the cookies whenever the program reports that they are invalid.
+
+First check that userName and password in login_pan.py are correct, then run the feature. The simulated login fills in the account and password automatically, but the human-verification step must be completed by hand; once it is done, press Enter in the program's command line to refresh the cookies.
+
+##### 1.2 Uploading files (config file: baidupan.conf)
+
+Uploads local files to a path on Baidu Cloud.
+
+Edit the fields under [upload] in config/baidupan.conf. The cloud path is relative: uploads are rooted at the /app/bypy folder, and that root cannot be changed. After editing the config, run the feature to upload.
+
+##### 1.3 Downloading files (config file: baidupan.conf)
+
+Downloads cloud files to a local path.
+
+Edit the fields under [download] in config/baidupan.conf; the cloud path is again relative. After editing the config, run the feature to download.
+
+##### 1.4 Creating a share link (config file: baidupan.conf)
+
+Shares files on Baidu Cloud with other people.
+
+Edit the fields under [createShareLink] in config/baidupan.conf; the cloud path is again relative. After editing the config, run the feature to create the share link. (Share-link creation breaks often; use it with caution.)
+
+#### 2. Running run_gen.py:
+
+Pipeline: stage 1: video2image_step1 > annotation vendor
+
+stage 2: check_data_step2 > classFeatureImages_step3 > gen_class_index_step4 > genAnn_step5 > splitTrainVal_step6 > gen_pb_tfrecord_step7
+
+stage 3: store the results into the database.
+
+
+
+##### 2.1 Splitting videos into images (step1) (config file: data_manage.conf)
+
+Edit the fields under [vidoe2image] in config/data_manage.conf: video_dir is the folder of input videos, save_dir is where the extracted images are saved. Videos can be split either every N frames or by a total image count; this feature only supports the frame-interval mode. The program under test_util/QT5/videos2imgs is recommended instead: it is an upgraded version of this feature that also supports a target total image count, with a friendly GUI and a matching exe.
+
+##### 2.2  Data preprocessing (step2-step7)
+
+Normally only the dir field under [data_manage] in config/data_manage.conf needs changing; everything else can stay at its default. step2-step7 run as one default pipeline.
+
+step2: quality check (config file: data_manage.conf)
+
+QC the annotation results returned by the vendor, checking that the xml and jpg files match one-to-one. Files without a counterpart are pulled out into the redundant_data folder.
+
+step3: crop and classify bounding boxes (config file: data_manage.conf)
+
+Crop each annotated rectangle and place it in its class folder; all class folders sit under class_img_dir. Box boundaries are checked at the same time: if a box is invalid, e.g. its coordinates exceed the background image's width or height, the image and its xml file are moved to the redundant_data folder.
+
+step4: generate the class-info files (config file: data_manage.conf)
+
+The class-info files are a json file and a txt file of annotation classes; they hold the set of every bounding-box class in the project.
+
+step5: generate the image-info files (config file: data_manage.conf)
+
+Extract each image's details (file path, feature-patch coordinates and other key info) and save them in both txt and csv formats. The txt comes in two flavors: a plain txt file, and a txt laid out for easy reading by yolov5. yolo_txt is the current default; change it via the flag field under [data_manage] in config/data_manage.conf.
+
+step6: train/val split (config file: data_manage.conf)
+
+Split the images, already classified via the class-info file, into training and test sets, keeping every data class represented in both the training and validation sets.
+
+step7: generate the tfrecord files for the training and test data, plus the ob.pbtxt file
+
+Nothing special here: it just generates three files, all at default paths that normally need no changes.
+
+##### 2.3 Database operations
+
+
+```
+img_basic: image file info table
+        +--------------------+--------------+------+-----+---------+------------------+
+        | Field              | Type         | Null | Key | Default | Extra            |
+        +--------------------+--------------+------+-----+---------+------------------+
+        | id                 | int(11)      | NO   | PRI | NULL    | auto_increment   | image file ID
+        | path               | varchar(128) | NO   |     | NULL    |                  | image path
+        | filename           | varchar(128) | YES  |     | NULL    |                  | file name
+        | width              | int(11)      | YES  |     | NULL    |                  | image width
+        | height             | int(11)      | YES  |     | NULL    |                  | image height
+        | depth              | int(11)      | YES  |     | NULL    |                  | image depth
+        | spot_id            | int(11)      | YES  |     | NULL    |                  | project ID (site ID)
+        | create_time        | datetime     | YES  |     | NULL    | CURRENT_TIMESTAMP| creation time
+        | update_time        | datetime     | YES  |     | NULL    | CURRENT_TIMESTAMP| update time
+        +--------------------+--------------+------+-----+---------+------------------+
+
+
+img_bbox: bounding-box info table
+        +--------------------+--------------+------+-----+---------+------------------+
+        | Field              | Type         | Null | Key | Default | Extra            |
+        +--------------------+--------------+------+-----+---------+------------------+
+        | id                 | int(11)      | NO   | PRI | NULL    | auto_increment   | bounding-box ID
+        | file_id            | int(11)      | NO   |     | NULL    |                  | image file ID
+        | xmin               | int(11)      | YES  |     | NULL    |                  | xmin coordinate
+        | ymin               | int(11)      | YES  |     | NULL    |                  | ymin coordinate
+        | xmax               | int(11)      | YES  |     | NULL    |                  | xmax coordinate
+        | ymax               | int(11)      | YES  |     | NULL    |                  | ymax coordinate
+        | class_name         | varchar(128) | YES  |     | NULL    |                  | class name
+        | state_name         | varchar(128) | YES  |     | NULL    |                  | state name
+        | bbox_ratio         | double       | YES  |     | NULL    |                  | box-to-background area ratio
+        | create_time        | datetime     | YES  |     | NULL    | CURRENT_TIMESTAMP| creation time
+        | update_time        | datetime     | YES  |     | NULL    | CURRENT_TIMESTAMP| update time
+        +--------------------+--------------+------+-----+---------+------------------+
+
+
+class_state: state/class info table
+        +-----------+-------------+------+-----+---------+-------+
+        | Field     | Type        | Null | Key | Default | Extra |
+        +-----------+-------------+------+-----+---------+-------+
+        |state_name | varchar(40) | NO   | PRI | NULL    |       | state name (e.g. E1-0)
+        |class_name | varchar(40) | NO   |     | NULL    |       | class name (e.g. E1)
+        |is_common  | varchar(20) | NO   |     | NULL    |       | whether the class is generic
+        +-----------+-------------+------+-----+---------+-------+
+
+scene_basic: scene info table
+        +-------------+--------------+------+-----+---------+------------------+
+        | Field       | Type         | Null | Key | Default | Extra            |
+        +-------------+--------------+------+-----+---------+------------------+
+        | scene_id    | int(11)      | NO   | PRI | NULL    | auto_increment   | scene ID
+        | description | varchar(128) | YES  |     | NULL    |                  | description
+        | scene       | varchar(128) | YES  |     | NULL    |                  | scene name
+        | create_time | datetime     | YES  |     | NULL    | CURRENT_TIMESTAMP| creation time
+        | update_time | datetime     | YES  |     | NULL    | CURRENT_TIMESTAMP| update time
+        +-------------+--------------+------+-----+---------+------------------+
+
+spot_basic: project info table
+        +--------------+--------------+------+-----+---------+------------------+
+        | Field        | Type         | Null | Key | Default | Extra            |
+        +--------------+--------------+------+-----+---------+------------------+
+        | spot_id      | int(11)      | NO   | PRI | NULL    | auto_increment   | project ID (site ID)
+        | scene_id     | int(11)      | NO   |     | NULL    |                  | scene ID
+        | spot         | varchar(128) | NO   | PRI | NULL    |                  | project name
+        | spot_CN      | varchar(128) | NO   |     | NULL    |                  | project name (Chinese)
+        | project_path | varchar(128) | NO   | PRI | NULL    |                  | project path
+        | create_time  | datetime     | YES  |     | NULL    | CURRENT_TIMESTAMP| creation time
+        | update_time  | datetime     | YES  |     | NULL    | CURRENT_TIMESTAMP| update time
+        +--------------+--------------+------+-----+---------+------------------+
+```
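+
+For reference, a minimal, hypothetical sketch of what the img_basic DDL could look like, reconstructed from the field table above; the authoritative CREATE statements live in the database/utils/utils_sql_create_DBImg conf files and may differ:
+
+```
+-- Hypothetical reconstruction from the field table above, not the shipped DDL.
+CREATE TABLE IF NOT EXISTS img_basic (
+    id          INT(11)      NOT NULL AUTO_INCREMENT,  -- image file ID
+    path        VARCHAR(128) NOT NULL,                 -- image path
+    filename    VARCHAR(128),                          -- file name
+    width       INT(11),                               -- image width
+    height      INT(11),                               -- image height
+    depth       INT(11),                               -- image depth
+    spot_id     INT(11),                               -- project ID (site ID)
+    create_time DATETIME DEFAULT CURRENT_TIMESTAMP,    -- creation time
+    update_time DATETIME DEFAULT CURRENT_TIMESTAMP,    -- update time
+    PRIMARY KEY (id)
+);
+```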
+
+
+
+If this is the first run and the database has not been created yet, run create_db.py in the database directory first. Set the database IP, user, password and related fields under [connect] in config/utils_db.conf so the connection succeeds.
+
+a. Data insertion (inserts rows directly by parsing every image's xml file) (config files: data_manage.conf, utils_db.conf)
+
+Edit the relevant fields under [database] in config/data_manage.conf; they hold the project path and the data folder relative to it.
+
+On the first insertion the scene table (scene_basic) and project table (spot_basic) are empty, and the program walks the user through populating them; the same flow applies when adding a new scene or a new project name. All inserted data is indexed by project ID (spot_id), so when re-inserting data for a project that already exists, the database first deletes the old rows and then inserts the current data.
+
+b. State-class deletion (config file: utils_db.conf)
+
+Deletes an incorrect class from an existing project, e.g. S1-0. The program prompts for the project ID and the state class to delete. Note: state classes are deleted one at a time; to drop every state under the S0 class you must delete S0-0 and S0-1 separately.
+
+c. State-class modification (config file: utils_db.conf)
+
+Renames a class in an existing project. As with deletion above, the program prompts for the state class to change and for its new value.
+
+##### 2.4 Small tools
+
+ 1. State-class deletion (config file: data_manage.conf)
+
+    Deletes a class by parsing local xml files and removing it directly. Double-check the class before running; set it via the remove_class_dir and remove_class fields under [data_manage] in config/data_manage.conf.
+
+ 2. Class modification
+
+    Similar to the deletion above; set the update_class_dir, update_new_class and update_old_class fields under [data_manage] in the config file.
+
+ 3. Convert image formats to jpg
+
+    Batch-converts common image formats to jpg (36 formats supported; see the code for the list). The program prompts for the folder containing the images.
+

BIN
code/data_manage/README.pdf


+ 24 - 0
code/data_manage/config/baidupan.conf

@@ -0,0 +1,24 @@
+############################### NOTE: every cloud folder path below is relative to the /app/bypy directory, so only write relative paths #####################
+############################### /app/bypy/ is the one and only root; otherwise downloads and uploads will not complete ################################################
+############################### e.g. to store files under /app/bypy/images, just write images ##############################################
+
+# file upload
+[upload]
+# local folder path (the folder to upload)
+localpath = C:\Users\Administrator\Desktop\Dell_IDC\Dell_IDC
+# cloud destination folder path
+remotepath = Dell_IDC/images
+
+
+# file download
+[download]
+# cloud source folder path
+remotepath = images
+# local folder path (where to download to)
+localpath = C:\Users\Administrator\Desktop\aaa\image
+
+
+# create a share link
+[createShareLink]
+# the folder to share
+remotepath = Dell_IDC/images

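The conf file above is plain INI. A minimal sketch, assuming Python's standard configparser, of reading the [upload] section; the repo's own config reading lives in its profileUtil helpers, which are not shown in this excerpt:

```python
# Minimal sketch: read baidupan.conf with the stdlib configparser.
# The repo's own reading code lives in its profileUtil helpers (not shown here).
from configparser import ConfigParser

conf = ConfigParser()
conf.read("config/baidupan.conf", encoding="utf-8")

local_path = conf.get("upload", "localpath")    # folder to upload
remote_path = conf.get("upload", "remotepath")  # relative to /app/bypy
print("would upload %s -> /app/bypy/%s" % (local_path, remote_path))
```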
+ 105 - 0
code/data_manage/config/data_manage.conf

@@ -0,0 +1,105 @@
+# Config for the video-to-image program
+# step1
+##############################################################################################
+[vidoe2image]
+# location of the videos to read (absolute path)
+# a Windows path can be pasted in directly, no escaping needed; the code normalizes the path separators
+video_dir = C:\Users\Administrator\Desktop\aaa
+# where to save the extracted images (absolute path)
+save_dir = C:\Users\Administrator\Desktop\aaa\image
+# filename prefix for the images saved by this run (naming rule: scene/site + video name + time + frame index + '.jpg'; the scene/site part is written into the config by hand, the rest is generated by the code)
+file_name = image_Wallseepage
+# extensions of the video files to convert; separate multiple formats with commas
+vidoes_extension = mp4,mov,MOV
+
+# The two variables below are the total number of images to extract and the frame interval (one image every frame_num frames); see the sketch after this file.
+# all_image_num: total images to extract per video
+# frame_num: frame interval
+# *** set exactly one of all_image_num and frame_num and set the other to zero; if both are greater than zero the program prefers the frame_num logic
+# *** negative values are replaced by their absolute value
+all_image_num = 0
+frame_num = 20
+
+
+# step2-step7
+##############################################################################################
+[data_manage]
+# choose between plain txt files and yolov5-specific txt files. flag: txt or yolo
+flag = yolo
+
+# data directory
+dir = /data2/object_detection/data/image/gushankuangye
+#dir = d:\Users\Administrator\Desktop\liudan\data
+# folder holding the xml and jpg files
+data_dir = total_data
+# folder for files missing their xml or jpg counterpart
+redundant_dir = redundant_data
+# folder for the annotation classes
+index_dir = class_index
+# folder for the training xml and jpg files
+train_dir = train
+#train_dir = train_data
+# folder for the test xml and jpg files
+#val_dir = val_data
+val_dir = val
+# folder for the csv image-info files
+label_csv_dir = label_csv
+# folder for the txt image-info files
+label_txt_dir = label_txt
+
+# dataset file names
+total_data_csv = total_data.csv
+total_data_txt = total_data.txt
+#total_data_yolo_txt = yolo_txt/total_data
+total_data_yolo_txt = yolo/total_data
+train_data_csv = train_data.csv
+train_data_txt = train_data.txt
+#train_data_yolo_txt = yolo_txt/train_data
+train_data_yolo_txt = yolo/train
+val_data_csv = val_data.csv
+val_data_txt = val_data.txt
+#val_data_yolo_txt = yolo_txt/val_data
+val_data_yolo_txt = yolo/val
+
+# json file of annotation classes
+class_index_json = class_index/ob_classes.json
+# txt file of annotation classes
+class_index_txt = class_index/ob_classes.txt
+# folder holding the per-class image crops
+class_img_dir = class_img_dir
+# folder for the tfrecord files
+tfrecord = tf_record
+# pbtxt of annotation classes
+pbtxt_path = ob.pbtxt
+# fraction of the data used for the test set
+split_ratio = 0.15
+# data distribution plot
+data_distribution = data_distribution.jpg
+
+## gen_pb_tfrecord_step7.py ##
+# location of val_data.csv
+val_csv_input = label_csv/val_data.csv
+# where to write val_data.record
+val_output_path = tf_record/val_data.record
+# location of train_data.csv
+train_csv_input = label_csv/train_data.csv
+# where to write train_data.record
+train_output_path = tf_record/train_data.record
+
+# small tool: delete a given class
+remove_class_dir = total_data
+remove_class = S18-1
+# small tool: update a given class
+update_class_dir = total_data
+# the class to change
+update_old_class = E3-4
+# the new class
+update_new_class = E3-3
+
+##############################################################################################
+# [database]
+# data directory
+# project_path = C:\Users\Administrator\Desktop\7.15
+# folder holding the xml and jpg files
+# total_data = total_data
+

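A minimal sketch of the all_image_num / frame_num rules described in [vidoe2image] above. The shipped logic lives in gen_data/video2image_step1.py, which this excerpt does not show; the function name frame_indices and the total_frames parameter are illustrative only:

```python
# Sketch of the sampling rules from [vidoe2image]; the shipped logic is in
# gen_data/video2image_step1.py (not part of this excerpt).
def frame_indices(total_frames: int, all_image_num: int, frame_num: int) -> list:
    # negative values are replaced by their absolute value
    all_image_num, frame_num = abs(all_image_num), abs(frame_num)
    if frame_num > 0:
        # frame_num wins when both are set: one image every frame_num frames
        return list(range(0, total_frames, frame_num))
    if all_image_num > 0:
        # spread all_image_num images evenly across the video
        step = max(total_frames // all_image_num, 1)
        return list(range(0, total_frames, step))[:all_image_num]
    return []

print(frame_indices(total_frames=200, all_image_num=0, frame_num=20))
# -> [0, 20, 40, ..., 180]
```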
+ 72 - 0
code/data_manage/config/utils_db.conf

@@ -0,0 +1,72 @@
+##############################################################################################
+[database]
+# data directory
+project_path = /data2/fengyang/sunwin/data/image/maanshan_gangteichang
+# folder holding the xml and jpg files
+total_data = total_data
+
+# connection
+[connect]
+dbhost=192.168.20.249
+dbport=3306
+dbuser=root
+dbpassword=root
+# database name
+dbname=db_img
+dbcharset=utf8
+# path of the SQL files that create the database; used when new tables need to be created
+sql_path = /database/utils/utils_sql_create_DBImg
+
+# insert
+[insert]
+class_state_insert = INSERT IGNORE  INTO class_state (state_name,class_name,is_common) VALUES (%%s,%%s,%%s)
+img_basic_insert = INSERT IGNORE  INTO img_basic (id,path,filename,width,height,depth,spot_id) VALUES (%%s,%%s,%%s,%%s,%%s,%%s,%%s)
+img_bbox_insert = INSERT IGNORE INTO img_bbox (id,file_id,xmin,ymin,xmax,ymax,class_name,state_id,bbox_ratio,spot_id) VALUES (%%s,%%s,%%s,%%s,%%s,%%s,%%s,%%s,%%s,%%s)
+# img_bbox_insert = INSERT IGNORE INTO img_bbox (id,file_id,xmin,ymin,xmax,ymax,class_name,state_name,bbox_ratio,spot_id) VALUES (?,?,?,?,?,?,?,?,?)
+scene_basic_insert = INSERT IGNORE INTO scene_basic (scene,description) VALUES (%%s,%%s)
+spot_basic_insert = INSERT IGNORE INTO spot_basic (scene_id, spot, spot_CN, project_path) VALUES (%%s,%%s,%%s,%%s)
+
+# delete
+[delete]
+# used when refreshing the database: removes the rows that are about to be re-inserted
+img_basic_delete = DELETE FROM img_basic where spot_id=%%s
+img_bbox_delete = DELETE FROM img_bbox where spot_id=%%s
+img_bbox_class_delete = DELETE FROM img_bbox where spot_id=%%s and state_name='%%s'
+
+# select
+[select]
+spot_basic_spot_id_select = select spot_id from spot_basic where spot='%%s'
+spot_basic_select = select spot_id, scene_id, spot, spot_CN, project_path from spot_basic
+scene_basic_select = select scene_id,scene from scene_basic
+img_basic_id_select = SELECT id FROM img_basic WHERE path='%%s'
+class_state_is_true_or_false_select = SELECT path,filename,width,height,depth,state_name,xmin,ymin,xmax,ymax from img_bbox ibb left JOIN class_state cs on ibb.state_id = cs.id LEFT JOIN img_basic ib  on ibb.file_id = ib.id where cs.is_common=%%s
+class_name_and_spot_id_select = SELECT path,filename,width,height,depth,state_name,xmin,ymin,xmax,ymax from img_bbox ibb left JOIN class_state cs on ibb.state_id = cs.id LEFT JOIN img_basic ib  on ibb.file_id = ib.id where cs.class_name='%%s' and ibb.spot_id=%%s
+class_name_select = SELECT path,filename,width,height,depth,state_name,xmin,ymin,xmax,ymax from img_bbox ibb left JOIN class_state cs on ibb.state_id = cs.id LEFT JOIN img_basic ib  on ibb.file_id = ib.id where cs.class_name='%%s'
+spot_id_select = SELECT path,filename,width,height,depth,state_name,xmin,ymin,xmax,ymax from img_bbox ibb left JOIN class_state cs on ibb.state_id = cs.id LEFT JOIN img_basic ib  on ibb.file_id = ib.id where ibb.spot_id=%%s
+
+img_bbox_x_y_select = SELECT file_id, state_name, xmin, ymin, xmax, ymax FROM img_bbox WHERE state_name='%%s'
+img_basic_w_h_d_select = SELECT id, path, filename, width, height, depth FROM img_basic WHERE id=%%s
+
+# update
+[update]
+img_bbox_class_update = UPDATE img_bbox SET class_name='%%s',state_name='%%s' WHERE spot_id='%%s' AND state_name='%%s'
+
+# generic-class flag
+[common]
+# list the non-generic top-level classes; e.g. for state S1-0 the S class is non-generic, so just write S. Separate entries with commas.
+is_common = S,R
+
+[gen_tfrecord]
+# value format: class_name1,spot_id1,percent1|class_name2,spot_id2,percent2|... (see the parsing sketch after this file)
+# class_name is the class name (not the state_name); use all to select every class_name
+# spot_id (int) is the site ID, found in the spot_basic table; use all to select every spot_id
+# percent (float) is the fraction of the data to fetch, in (0-1]; use 1 for all of it
+# entries are separated by the | character
+class_name_and_spot = E1,1,0.2|K10,1,0.3|all,1,1

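A minimal sketch of how the [gen_tfrecord] class_name_and_spot value parses, following the format the comments describe; the repo's real parser (in its utils) is not shown in this excerpt, and the function name here is illustrative:

```python
# Sketch: parse the [gen_tfrecord] class_name_and_spot value.
# Entries are |-separated; each entry is class_name,spot_id,percent.
def parse_class_name_and_spot(raw: str) -> list:
    selections = []
    for entry in raw.split("|"):
        class_name, spot_id, percent = entry.split(",")
        selections.append({
            "class_name": class_name,             # "all" selects every class
            "spot_id": spot_id if spot_id == "all" else int(spot_id),
            "percent": float(percent),            # fraction in (0, 1]
        })
    return selections

print(parse_class_name_and_spot("E1,1,0.2|K10,1,0.3|all,1,1"))
```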
+ 89 - 0
code/data_manage/gen_data/check_data_step2.py

@@ -0,0 +1,89 @@
+"""
+# File       : check_data_step2.py
+# Time       :21.5.29 12:26
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:对标注公司标注好的文件进行二次质检,确保xml文件和jpg文件一一对应 step2_检查是否存在漏标的图片
+"""
+import glob
+import shutil
+import sys
+
+from gen_data.utils import profileUtil, pathUtil, fileUtil  # 配置文件操作、路径操作、文件筛选工具。
+
+class checkFile:
+    def __init__(self):
+        conf = profileUtil.ConfigUtil()
+        all_items = conf.get_config_items(conf.ConfPath, conf.data_manageSection)
+        self.project_path = pathUtil.path_format(all_items["dir"])
+        self.redundant_dir = pathUtil.path_format_join(all_items["dir"], all_items["redundant_dir"])  # where problem xml/jpg files are moved
+        self.file_names = pathUtil.path_format_join(all_items["dir"], all_items["data_dir"])  # where the images live
+
+    def diff_check(self, list1, list2, file_type):
+        """
+        Find files present in list1 but missing from list2, move the problem files into the redundant_dir folder, and return their absolute paths
+        :param list1: list of file stems
+        :param list2: list of file stems
+        :param file_type: file extension
+        :return: list of problem file paths
+        """
+        problem_list = []
+        diff_list = set(list1).difference(list2)  # stems in list1 but not in list2 (set difference)
+        for diff in diff_list:
+            problem_file_name = diff + file_type
+            problem_file_path = pathUtil.path_format_join(self.file_names, problem_file_name)
+            move_file_path = pathUtil.path_format_join(self.redundant_dir, problem_file_name)
+            problem_list.append(problem_file_path)
+            shutil.move(problem_file_path, move_file_path)  # move the file out of the data folder
+        if len(problem_list) > 0:
+            if file_type == '.xml':
+                problem_list.extend([{'.xml files missing their .jpg counterpart': problem_list}])
+                print('These .xml files have no matching .jpg file: %s' % (problem_list))
+            else:
+                problem_list.extend([{'.jpg files missing their .xml counterpart': problem_list}])
+                print('These .jpg files have no matching .xml file: %s' % (problem_list))
+        return problem_list
+
+
+    def check_file(self, xml_name_list, jpg_name_list):
+        """
+        Collect the paths of xml and jpg files that fail to match one-to-one and return them as a list
+        :param xml_name_list: list of xml file stems
+        :param jpg_name_list: list of jpg file stems
+        :return: list of problem files
+        """
+        diff1 = self.diff_check(xml_name_list, jpg_name_list, '.xml')
+        diff2 = self.diff_check(jpg_name_list, xml_name_list, '.jpg')
+        problem_list = diff1 + diff2
+        return problem_list
+
+    def main(self):
+
+        xml_name_list = [pathUtil.path_format(file).split('/')[-1].split('.xml')[0] for file in glob.glob(self.file_names + '/*.xml')]
+        jpg_name_list = [pathUtil.path_format(file).split('/')[-1].split('.jpg')[0] for file in glob.glob(self.file_names + '/*.jpg')]
+        if len(xml_name_list) + len(jpg_name_list) < 1:
+            print('No data found; please check the data under %s' % self.file_names)
+            sys.exit(-1)
+        pathUtil.mkdir_new(self.redundant_dir)
+        problem_list = self.check_file(xml_name_list, jpg_name_list)
+        if problem_list:
+            fileUtil.writelog(problem_list)
+            print('Problem files were moved to: %s' % (self.redundant_dir))
+        else:
+            print('Check complete; the xml and jpg files all match!')
+        print('\n-----------------------------step2 done-----------------------------\n')
+
+if __name__ == "__main__":
+    checkFile().main()
+

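The heart of check_data_step2.py is the pair of set differences in diff_check. A tiny standalone illustration with made-up file stems:

```python
# Standalone illustration of the two-way check in diff_check (made-up stems).
xml_stems = {"img_001", "img_002", "img_003"}
jpg_stems = {"img_001", "img_003", "img_004"}

print(xml_stems - jpg_stems)  # {'img_002'}: .xml files missing their .jpg
print(jpg_stems - xml_stems)  # {'img_004'}: .jpg files missing their .xml
```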
+ 142 - 0
code/data_manage/gen_data/classFeatureImages_step3.py

@@ -0,0 +1,142 @@
+
+"""
+# File       : classFeatureImages_step3.py
+# Time       :21.5.28 14:35
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:将标注后的图像中的特征图像提取出来后按照类别分类 step3_截取图片中标注后的特征图像,并对其分类处理
+"""
+import math
+import cv2
+import shutil
+import time
+import numpy as np
+from tqdm import tqdm
+from gen_data.utils import coorUtil, profileUtil, pathUtil, fileUtil
+import threading
+# import matplotlib.pyplot as plt
+
+class classFeatureImages:
+
+    def __init__(self):
+        conf = profileUtil.ConfigUtil()
+        all_items = conf.get_config_items(conf.ConfPath, conf.data_manageSection)
+        self.redundant_dir = pathUtil.path_format_join(all_items["dir"], all_items["redundant_dir"])
+        self.img_dir = pathUtil.path_format_join(all_items["dir"], all_items["data_dir"])  # where the images live
+        self.class_img_dir = pathUtil.path_format_join(all_items["dir"], all_items["class_img_dir"])  # where the classified crops are saved
+
+
+    def cv_imread(self, filePath):
+        """
+        Read an image; this works around cv2.imread() failing on non-ascii (e.g. Chinese) paths
+        :param filePath: file path
+        :return:
+        """
+        cv_img = cv2.imdecode(np.fromfile(filePath, dtype=np.uint8), -1)
+        # cv2.imread cannot handle non-ascii paths, so read the raw bytes
+        # with np.fromfile() first and decode them with cv2.imdecode()
+        return cv_img
+
+    # def cv_imread(self, filePath):
+    #     # matplotlib-based variant
+    #     img = plt.imread(filePath)
+    #     # opencv uses BGR channel order, so convert here
+    #     cv_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+    #     return cv_img
+
+    # def cv_imread(self, filePath=""):
+    #     file_path_gbk = filePath.encode('gbk')  # unicode to gbk: the string becomes a byte array
+    #     cv_img = cv2.imread(file_path_gbk.decode('gbk'))  # decode the byte array straight back to a string
+    #     return cv_img
+
+
+    @staticmethod
+    def splitdf(df, num):
+        """Split a DataFrame into num equal chunks plus the remainder rows."""
+        linenum = math.floor(len(df) / num)  # rows per chunk; len(df) is the total annotation count
+        pdlist = []
+        for i in range(num):  # i runs from 0 to num-1
+            pd1 = df[i * linenum:(i + 1) * linenum]
+            pdlist.append(pd1)
+        pd1 = df[num * linenum:len(df)]  # remainder rows left over after the num full chunks
+        pdlist.append(pd1)
+        return pdlist
+
+    def isexists_class(self, df, class_img_dir):
+        """
+        Collect every image class via the ['class_name'] field, then create each class folder that does not already exist
+        :param df: Dataframe
+        :param class_img_dir: parent directory of the class folders
+        :return:
+        """
+        group_df = df.groupby('class_name')  # group by 'class_name'
+        for k, _ in group_df:
+            class_dir = pathUtil.path_format_join(class_img_dir, str(k))
+            pathUtil.mkdir(class_dir)
+
+
+
+    def gen_class_img_thread(self, csv_df, img_dir, class_img_dir):
+        self.isexists_class(csv_df, class_img_dir)
+        threads = []  # pool of worker threads to run
+        df_list = self.splitdf(csv_df, 10)
+        for df in df_list:
+            t = threading.Thread(target=self._gen_class_img, args=(df, img_dir, class_img_dir))
+            # one thread per chunk; target is the worker function, args its three arguments (df, img_dir, class_img_dir)
+            t.daemon = True  # daemon threads die with the main thread
+            threads.append(t)
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()  # wait for every worker to finish before moving on
+        print('\ntotal number of bounding boxes:', len(csv_df))
+
+    def _gen_class_img(self, csv_df, img_dir, class_img_dir):
+        """
+        Crop every annotated feature patch out of the images and store it in the folder of its class
+        :param img_dir: directory of the original images
+        :param class_img_dir: parent directory of the per-class folders; joined with class_name to build each crop's output path
+        :return:
+        """
+        errors = []
+        for index, row in tqdm(csv_df.iterrows(), total=len(csv_df), ncols=60, position=0):  # position=0 keeps the bar on a single line
+            # index is the row index, row its content; total is the item count, ncols the bar width
+            filename, class_name = row["filename"], row["class_name"]
+            image_path = pathUtil.path_format_join(img_dir, filename)
+            error, error_dic = coorUtil.check_coor(image_path, row)
+            if error:
+                old_file_path = '.'.join(image_path.split('.')[0:-1])  # the path without its extension
+                new_file = '.'.join(image_path.split('\\')[-1].split('/')[-1].split('.')[0:-1])  # the bare file stem
+                new_file_path = pathUtil.path_format_join(self.redundant_dir, new_file)
+                print(new_file_path)
+                try:
+                    shutil.move(old_file_path + '.jpg', new_file_path + '.jpg')
+                    shutil.move(old_file_path + '.xml', new_file_path + '.xml')
+                    errors.extend([error_dic])
+                except:
+                    pass
+                continue
+            xmin, ymin, xmax, ymax = row["xmin"], row["ymin"], row["xmax"], row["ymax"]
+            class_file = pathUtil.path_format_join(class_img_dir, "{}".format(row["class_name"]))
+            image = self.cv_imread(image_path)
+            cropimg = image[int(ymin):int(ymax), int(xmin):int(xmax)]  # crop the box
+            img_path = pathUtil.path_format_join(class_file, filename.split('.jpg')[0] + '_' + str(xmin) + '_' + str(ymin) + '_' + str(xmax) + '_' + str(ymax) + '.jpg')
+            cv2.imwrite(img_path, cropimg)  # note: unlike cv_imread, this call does not handle non-ascii paths
+        if errors:
+            print('problem annotations:', errors)
+            fileUtil.writelog(errors)
+
+    def main(self):
+        pathUtil.mkdir_new(self.class_img_dir)
+        csv_df = profileUtil.xmlUtil().xml_parse(self.img_dir)
+        self.gen_class_img_thread(csv_df, self.img_dir, self.class_img_dir)
+        print('\n-----------------------------step3 done-----------------------------\n')
+
+
+if __name__ == '__main__':
+    classFeatureImages().main()

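cv_imread above works around cv2.imread's non-ASCII path limitation, but the cv2.imwrite call in _gen_class_img leaves the same limitation unaddressed. A minimal sketch of a symmetric write helper, not part of this commit, using the same trick in reverse:

```python
# Sketch of a non-ascii-safe counterpart to cv_imread (not in this commit):
# encode in memory with cv2.imencode, then write the bytes with numpy.
import cv2
import numpy as np

def cv_imwrite(file_path, img, ext=".jpg"):
    ok, buf = cv2.imencode(ext, img)  # encode to an in-memory byte buffer
    if not ok:
        raise IOError("encoding failed for %s" % file_path)
    buf.tofile(file_path)  # np.ndarray.tofile handles non-ascii paths
```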
+ 67 - 0
code/data_manage/gen_data/class_operate.py

@@ -0,0 +1,67 @@
+"""
+# File       : class_operate.py
+# Time       :21.6.9 13:42
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:小工具功能:将本地xml文件中多余的类删除,或者是修改已经存在的旧类
+"""
+import glob
+from gen_data.utils import pathUtil, profileUtil, strUtil
+import xml.etree.ElementTree as ET
+
+class classOperate:
+    def __init__(self):
+        conf = profileUtil.ConfigUtil()
+        all_items = conf.get_config_items(conf.ConfPath, conf.data_manageSection)
+        self.remove_class_dir = pathUtil.path_format_join(all_items["dir"], all_items["remove_class_dir"])
+        self.remove_class_str = strUtil.profile2str(all_items["remove_class"])
+        self.update_class_dir = pathUtil.path_format_join(all_items["dir"], all_items["update_class_dir"])
+        self.update_old_class = strUtil.profile2str(all_items["update_old_class"])
+        self.update_new_class = strUtil.profile2str(all_items["update_new_class"])
+
+    def remove_class(self):
+        """
+        Delete every annotation of class remove_class from the xml files under remove_class_dir
+        :return:
+        """
+        count_file, count_class = 0, 0
+        flag = input("Really delete class %s under %s? Enter Y to confirm. " % (self.remove_class_str, self.remove_class_dir))
+        if flag == 'Y' or flag == 'y':
+            input("Confirm that the [remove_class_dir] and [remove_class] options in data_manage.conf are set correctly, then press Enter to continue...")
+            for filename in glob.glob(self.remove_class_dir + '/*.xml'):
+                c = 0
+                dom = ET.parse(filename)
+                root = dom.getroot()
+                for obj in root.findall('object'):
+                    if obj.find('name').text == self.remove_class_str:
+                        c = 1
+                        count_class += 1
+                        root.remove(obj)
+                count_file += c
+
+                # write back to the same file
+                dom.write(filename, xml_declaration=True)
+            print('%d xml files contained that class, %d annotations in total; all deleted' % (count_file, count_class))
+        else:
+            print('Run cancelled.')
+
+
+    def update_class(self):
+        count = 0
+        flag = input("Really rename class %s to class %s under %s? Enter Y to confirm. " % (self.update_old_class, self.update_new_class, self.update_class_dir))
+        if flag == 'Y' or flag == 'y':
+            input("Confirm that the [update_class_dir], [update_old_class] and [update_new_class] options in data_manage.conf are set correctly, then press Enter to continue...")
+            for filename in glob.glob(self.update_class_dir + '/*.xml'):
+                dom = ET.parse(filename)
+                root = dom.getroot()
+                for obj in root.findall('object'):
+                    if obj.find('name').text == self.update_old_class:
+                        obj.find('name').text = self.update_new_class
+                        count = count + 1
+
+                # write back to the same file
+                dom.write(filename, xml_declaration=True)
+            print("class [%s] was renamed to class [%s] across %d xml files." % (self.update_old_class, self.update_new_class, count))
+        else:
+            print('Run cancelled.')

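For reference, remove_class operates on PASCAL-VOC-style annotations like the following hypothetical snippet; the whole `<object>` element whose `<name>` matches remove_class is dropped and the rest is left untouched:

```
<!-- hypothetical VOC-style annotation; the <object> whose <name> matches
     remove_class (e.g. S18-1) is removed, the other objects are kept -->
<annotation>
    <filename>image_001.jpg</filename>
    <object>
        <name>S18-1</name>
        <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>50</xmax><ymax>60</ymax></bndbox>
    </object>
    <object>
        <name>E3-3</name>
        <bndbox><xmin>100</xmin><ymin>120</ymin><xmax>150</xmax><ymax>160</ymax></bndbox>
    </object>
</annotation>
```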
+ 43 - 0
code/data_manage/gen_data/genAnn_step5.py

@@ -0,0 +1,43 @@
+"""
+# File       : genAnn_step5.py
+# Time       :21.5.31 9:10
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:生成图片的具体信息:文件路径、特征图像的坐标等关键信息 step5_生成图片信息文件
+"""
+import time
+from gen_data.utils import profileUtil, pathUtil, strUtil
+
+class ganAnn:
+    def __init__(self):
+        conf = profileUtil.ConfigUtil()
+        all_items = conf.get_config_items(conf.ConfPath, conf.data_manageSection)
+        self.data_dir = pathUtil.path_format_join(all_items["dir"], all_items["data_dir"])  # 存放xml和jpg的文件路径
+        self.label_csv_dir = pathUtil.path_format_join(all_items["dir"], all_items["label_csv_dir"])  # 存放图片信息的csv文件路径
+        self.label_txt_dir = pathUtil.path_format_join(all_items["dir"], all_items["label_txt_dir"])  # 存放图片信息的txt文件路径
+        self.class_index_json = pathUtil.path_format_join(all_items["dir"], all_items["class_index_json"])  # 存放标注种类的json文件路径
+        self.total_data_txt = pathUtil.path_format_join(self.label_txt_dir, all_items["total_data_txt"])  # 存放图片数据的txt文件路径
+        # self.total_data_yolo_txt = pathUtil.path_format_join(all_items["dir"], all_items["total_data_yolo_txt"])
+        self.total_data_csv = pathUtil.path_format_join(self.label_csv_dir, all_items["total_data_csv"])
+        self.class_index_json = pathUtil.path_format_join(all_items["dir"], all_items["class_index_json"])
+        self.flag = strUtil.profile2str(all_items["flag"])
+        self.data_distribution = pathUtil.path_format_join(all_items["dir"], all_items["data_distribution"])
+
+
+    def main(self):
+        pathUtil.mkdir_new(self.label_csv_dir)
+        pathUtil.mkdir_new(self.label_txt_dir)
+        # pathUtil.mkdir_new(self.total_data_yolo_txt)
+        xml_util = profileUtil.xmlUtil()
+        # if self.flag.lower() == 'yolo':
+            # print(self.total_data_yolo_txt)
+            # xml_util.xml_to_yolo_txt(self.total_data_yolo_txt, self.data_dir, self.class_index_json)
+        # else:
+        xml_util.xml_to_txt(self.total_data_txt, self.data_dir, self.class_index_json)
+        xml_util.xml_to_csv(self.total_data_csv, self.data_dir, self.data_distribution)
+        print('\n-----------------------------step5完成-----------------------------\n')
+        time.sleep(0.5)
+
+if __name__ == "__main__":
+    ganAnn().main()

+ 70 - 0
code/data_manage/gen_data/gen_class_index_step4.py

@@ -0,0 +1,70 @@
+"""
+# File       : gen_class_index_step4.py
+# Time       :21.5.29 10:10
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:生成图片的所有分类信息 step4_生成类别信息文件
+"""
+import json
+from gen_data.utils import fileUtil, profileUtil, pathUtil
+import os
+
+class genClassIndex:
+    def __init__(self):
+        conf = profileUtil.ConfigUtil()
+        all_items = conf.get_config_items(conf.ConfPath, conf.data_manageSection)
+
+        self.data_dir = pathUtil.path_format_join(all_items["dir"], all_items["data_dir"])  # 存放xml和jpg的文件夹
+        self.index_dir = pathUtil.path_format_join(all_items["dir"], all_items["index_dir"])  # 存放标注种类的文件夹
+        self.class_index_txt_path = pathUtil.path_format_join(all_items["dir"], all_items["class_index_txt"])  # 存放标注种类的txt文件
+        self.class_index_json_path = pathUtil.path_format_join(all_items["dir"], all_items["class_index_json"]) # 存放标注种类的json文件
+        self.dir_path = pathUtil.path_format_join(all_items["dir"], all_items["dir"])
+
+    def gen_class_index_json_txt(self, xml_df, class_index_json_path, class_index_txt_path, dir_path):
+        """
+        将分好的class_name分别写入json文件和txt文件
+        :param data_dir: 存放xml和jpg的文件夹
+        :param class_index_json_path: 存放标注种类的json文件夹
+        :param class_index_txt_path: 存放标注种类的txt文件夹
+        :return:
+        """
+        class_index_dict = dict()
+        class_name_set = set(xml_df["class_name"].values)  # set是集合,values是键值
+        class_name_list = list(class_name_set)
+        class_name_list.sort()  # 排序
+        for index, class_name in enumerate(class_name_list):  # 用于将一个可遍历数据对象(列表、元组或者字符串)组合为一个索引序列,同时列出数据和数据下标,通常在for循环中
+            class_index_dict[class_name] = index
+
+        # save json file
+        with open(class_index_json_path, "w") as json_file:
+            json_file.write(json.dumps(class_index_dict, indent=4))  # 把一个Python对象编码转换成Json字符串 json.dumps( )
+        print("总类别数: %d" % (len(class_name_list)))
+        print("写入 %s 完成" % (class_index_json_path))
+        #f = open(dir_path + "\\classes.txt", 'a')
+        #classresult = ''
+        #for i in class_name_list:
+            #classresult = classresult + i + "\n"
+        #f.write(classresult)
+        #f.close()
+
+
+        # save txt file
+        with open(class_index_txt_path, "w") as txt_file:
+            index_class_dict = {value: key for key, value in class_index_dict.items()}
+            for i in range(len(index_class_dict)):
+                txt_file.write(index_class_dict[i])
+                txt_file.write("\n")
+
+        print("写入 %s 完成" % (class_index_txt_path))
+
+
+    def main(self):
+        pathUtil.mkdir_new(self.index_dir)
+        xml_df = profileUtil.xmlUtil().xml_parse(self.data_dir)
+        self.gen_class_index_json_txt(xml_df, self.class_index_json_path, self.class_index_txt_path, self.dir_path)
+        print('\n-----------------------------step4完成-----------------------------\n')
+
+
+if __name__ == "__main__":
+    genClassIndex().main()

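For a hypothetical project whose boxes carry only the classes E1-0 and E3-1, gen_class_index_json_txt would write the two files as:

```
# class_index/ob_classes.json
{
    "E1-0": 0,
    "E3-1": 1
}

# class_index/ob_classes.txt
E1-0
E3-1
```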
+ 145 - 0
code/data_manage/gen_data/gen_pb_tfrecord_step7.py

@@ -0,0 +1,145 @@
+"""
+# File       : gen_pb_tfrecord_step7.py
+# Time       :21.6.4 10:10
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:生成pbtxt文件和Tfrecord文件 step7_生成特定的文件
+"""
+import os
+import io
+import json
+import pandas as pd
+import tensorflow.compat.v1 as tf
+from PIL import Image
+from collections import namedtuple
+from research.object_detection.utils import dataset_util
+from research.object_detection.utils import label_map_util
+from gen_data.utils import profileUtil, pathUtil, strUtil, fileUtil
+
+class genPb:
+    def __init__(self):
+        conf = profileUtil.ConfigUtil()
+        all_items = conf.get_config_items(conf.ConfPath, conf.data_manageSection)
+        self.class_json_path = pathUtil.path_format_join(all_items["dir"], all_items["class_index_json"])
+        self.pbtxt_path = pathUtil.path_format_join(all_items["dir"], all_items["pbtxt_path"])
+
+    def gen_pbtxt(self):
+        with open(self.class_json_path, 'r') as json_file:
+            json_dict = json.load(json_file)
+        json_len = len(json_dict)
+        total_list = ["" for i in range(json_len)]
+        for key, value in json_dict.items():
+            line_content_list = []
+            line_content_list.append("item {\n")
+            line_content_list.append("  id: {}\n".format(value + 1))
+            line_content_list.append("  name: '{}'\n".format(key))
+            line_content_list.append("}\n")
+            line_content_list.append("\n")
+            fill_content = "".join(line_content_list)
+            total_list[value] = fill_content
+
+        with open(self.pbtxt_path, 'w') as pbtxt_file:
+            for i in total_list:
+                pbtxt_file.write(i)
+
+    def main(self):
+        self.gen_pbtxt()
+        print('created %s' % self.pbtxt_path)
+
+class genTfrecord:
+    def __init__(self):
+        conf = profileUtil.ConfigUtil()
+        all_items = conf.get_config_items(conf.ConfPath, conf.data_manageSection)
+        self.tfrecord = pathUtil.path_format_join(all_items["dir"], all_items["tfrecord"])
+
+        flags = tf.app.flags  # command-line flag registry
+        flags.DEFINE_string("val_csv_input", pathUtil.path_format_join(all_items["dir"], all_items["val_csv_input"]),
+                            "Path to the CSV input")
+        # flags.DEFINE_string registers a string flag named 'val_csv_input'; its default is the path above, and the last argument is its help text
+        flags.DEFINE_string("images_input", pathUtil.path_format_join(all_items["dir"], all_items["data_dir"]),
+                            "Path to the images input")
+        flags.DEFINE_string("val_output_path",
+                            pathUtil.path_format_join(all_items["dir"], all_items["val_output_path"]),
+                            "Path to output TFRecord")
+        flags.DEFINE_string("label_map_path", pathUtil.path_format_join(all_items["dir"], "ob.pbtxt"),
+                            "Path to label map proto")
+        flags.DEFINE_string("train_csv_input",
+                            pathUtil.path_format_join(all_items["dir"], all_items["train_csv_input"]),
+                            "Path to the CSV input")
+        flags.DEFINE_string("train_output_path",
+                            pathUtil.path_format_join(all_items["dir"], all_items["train_output_path"]),
+                            "Path to output TFRecord")
+        self.FLAGS = flags.FLAGS
+
+
+    def split(self, df, group):
+        data = namedtuple("data", ["filename", "object"])  # a namedtuple type with 'filename' and 'object' fields
+        gb = df.groupby(group)  # group the annotations by file
+        return [data(filename, gb.get_group(x)) for filename, x in
+                zip(gb.groups.keys(), gb.groups)]
+
+
+    def create_tf_example(self, group, label_map_dict, images_path):
+        # tf.gfile.GFile(filename, mode) returns a file handle, much like Python's
+        # built-in open(): filename is the file to open, mode how to read/write it.
+        with tf.gfile.GFile(os.path.join(
+                images_path, "{}".format(group.filename)), "rb") as fid:  # read the image into memory
+            encoded_jpg = fid.read()
+        encoded_jpg_io = io.BytesIO(encoded_jpg)  # the image as an in-memory bytes object
+        image = Image.open(encoded_jpg_io)
+        width, height = image.size
+
+        filename = group.filename.encode("utf8")
+        image_format = b"jpg"
+        xmins, xmaxs, ymins, ymaxs, classes_text, classes = [], [], [], [], [], []
+        for index, row in group.object.iterrows():
+            xmins.append(row["xmin"] / width)
+            xmaxs.append(row["xmax"] / width)
+            ymins.append(row["ymin"] / height)
+            ymaxs.append(row["ymax"] / height)
+            classes_text.append(str(row['class_name']).encode("utf8"))
+            classes.append(label_map_dict[str(row['class_name'])])
+
+        tf_example = tf.train.Example(features=tf.train.Features(feature={  # build the Example from a dict of features
+            "image/height": dataset_util.int64_feature(height),
+            "image/width": dataset_util.int64_feature(width),
+            "image/filename": dataset_util.bytes_feature(filename),
+            "image/source_id": dataset_util.bytes_feature(filename),
+            "image/encoded": dataset_util.bytes_feature(encoded_jpg),
+            "image/format": dataset_util.bytes_feature(image_format),
+            "image/object/bbox/xmin": dataset_util.float_list_feature(xmins),
+            "image/object/bbox/xmax": dataset_util.float_list_feature(xmaxs),
+            "image/object/bbox/ymin": dataset_util.float_list_feature(ymins),
+            "image/object/bbox/ymax": dataset_util.float_list_feature(ymaxs),
+            "image/object/class/text": dataset_util.bytes_list_feature(classes_text),
+            "image/object/class/label": dataset_util.int64_list_feature(classes),
+        }))
+        return tf_example
+
+
+    def gen_tfrecord(self, val_output_path, val_csv_input):
+        # generate a tfrecord from a csv
+        writer = tf.python_io.TFRecordWriter(val_output_path)
+        label_map_dict = label_map_util.get_label_map_dict(self.FLAGS.label_map_path)
+        images_path = self.FLAGS.images_input
+        examples = pd.read_csv(val_csv_input)
+        grouped = self.split(examples, "filename")
+        for group in grouped:
+            tf_example = self.create_tf_example(group, label_map_dict, images_path)
+            writer.write(tf_example.SerializeToString())  # serialize the example and append it to the TFRecord file
+
+        writer.close()
+        print("created %s" % val_output_path)
+
+
+    def main(self):
+        pathUtil.mkdir_new(self.tfrecord)
+        self.gen_tfrecord(self.FLAGS.val_output_path, self.FLAGS.val_csv_input)
+        self.gen_tfrecord(self.FLAGS.train_output_path, self.FLAGS.train_csv_input)
+        print('\n-----------------------------step7 done-----------------------------\n')
+
+
+if __name__ == "__main__":
+    # genPb().main()
+    genTfrecord().main()

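A minimal sketch, not part of the commit, of reading one example back out of a generated record to spot-check it; it uses a subset of the feature keys create_tf_example writes, the tf_record/val_data.record default path from data_manage.conf, and assumes an eager-mode TF with the tf.data API available:

```python
# Sketch: read one example back from the generated TFRecord (verification only).
import tensorflow as tf

feature_spec = {
    "image/height": tf.io.FixedLenFeature([], tf.int64),
    "image/width": tf.io.FixedLenFeature([], tf.int64),
    "image/filename": tf.io.FixedLenFeature([], tf.string),
    "image/object/bbox/xmin": tf.io.VarLenFeature(tf.float32),
    "image/object/class/text": tf.io.VarLenFeature(tf.string),
}

dataset = tf.data.TFRecordDataset("tf_record/val_data.record")
for raw in dataset.take(1):
    example = tf.io.parse_single_example(raw, feature_spec)
    print(example["image/filename"].numpy(),
          example["image/width"].numpy(), example["image/height"].numpy())
```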
+ 65 - 0
code/data_manage/gen_data/other2jpg.py

@@ -0,0 +1,65 @@
+"""
+# File       : other2jpg.py
+# Time       :21.6.23 10:29
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:将
+"""
+
+import glob
+from PIL import Image
+import os
+from gen_data.utils import pathUtil
+import xml.etree.ElementTree as ET
+
+
+def convert(source, tmp_dir_path):
+    """
+    Convert one image to jpg, rewriting its xml annotation (if any) to match
+    :param source: path of the source image
+    :param tmp_dir_path: output directory for the converted files
+    :return:
+    """
+    im = Image.open(source)
+    file_ = os.path.splitext(source)[0]
+    name_ = os.path.splitext(source)[0].split('/')[-1]
+    xmlfile = file_ + '.xml'
+    if os.path.exists(xmlfile):
+        dom = ET.parse(xmlfile)
+        root = dom.getroot()
+        root.find('filename').text = name_ + '.jpg'
+        root.find('path').text = os.path.splitext(root.find('path').text)[0] + '.jpg'
+
+        # write the updated xml next to the converted image
+        xmlfile_path = pathUtil.path_format_join(tmp_dir_path, os.path.splitext(source)[0].split('/')[-1]) + '.xml'
+        dom.write(xmlfile_path, xml_declaration=True)
+
+
+    jpgpath = pathUtil.path_format_join(tmp_dir_path, os.path.splitext(source)[0].split('/')[-1]) + '.jpg'
+
+    im.save(jpgpath)
+
+def main():
+
+    supports = ['bmp', 'dib', 'gif', 'tif', 'tiff', 'jfif', 'jpe', 'jpeg', 'pbm', 'pgm', 'ppm',
+                'pnm', 'png', 'apng', 'pcx', 'ps', 'eps', 'jp2', 'j2k', 'jpc', 'jpf', 'jpx', 'j2c', 'ico', 'im',
+                'mpo', 'pdf', 'bw', 'rgb', 'rgba', 'sgi', 'tga', 'icb', 'vda', 'vst', 'webp']
+    flag = True
+    while flag:
+        path = input('Enter the directory of images to convert (0 to quit): ')
+        dir_path = pathUtil.path_format(path)
+        if dir_path == '0':
+            exit(0)
+        elif os.path.exists(dir_path):
+            count = 0
+            tmp_dir_path = pathUtil.path_format_join(dir_path, '_temp_imgs')
+            pathUtil.mkdir_new(tmp_dir_path)
+            for support in supports:
+                for file_path in glob.glob(dir_path + '/*.' + support):
+                    convert(pathUtil.path_format(file_path), tmp_dir_path)
+                    count += 1
+            print('%d images converted; see the %s directory' % (count, tmp_dir_path))
+            exit(0)
+        else:
+            print('Path does not exist; please try again')

+ 240 - 0
code/data_manage/gen_data/splitTrainVal_step6.py

@@ -0,0 +1,240 @@
+"""
+# File       : splitTrainVal_step6.py
+# Time       :21.6.1 13:49
+# Author     :FEANGYANG
+
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:在根据类别信息文件对图片进行分类后的基础上进行训练集和测试集的划分操作 step6_训练集和验证集的划分
+"""
+import math
+import shutil
+import sys
+
+import pandas as pd
+from sklearn.model_selection import train_test_split
+from tqdm import tqdm
+
+from gen_data.utils import profileUtil, pathUtil, strUtil
+
+
+class splitTrainVal:
+    def __init__(self):
+        conf = profileUtil.ConfigUtil()
+        all_items = conf.get_config_items(conf.ConfPath, conf.data_manageSection)
+        self.project_name = pathUtil.path_format(all_items["dir"]).split('/')[-1]
+        self.data_dir = pathUtil.path_format_join(all_items["dir"], all_items["data_dir"])
+        label_csv_dir = pathUtil.path_format_join(all_items["dir"], all_items["label_csv_dir"])
+        label_txt_dir = pathUtil.path_format_join(all_items["dir"], all_items["label_txt_dir"])
+        self.train_dir = pathUtil.path_format_join(all_items["dir"], all_items["train_dir"])
+        self.val_dir = pathUtil.path_format_join(all_items["dir"], all_items["val_dir"])
+        self.class_index_json = pathUtil.path_format_join(all_items["dir"], all_items["class_index_json"])
+        self.total_data_csv = pathUtil.path_format_join(label_csv_dir, all_items["total_data_csv"])
+        self.split_ratio = all_items["split_ratio"]
+
+        self.flag = strUtil.profile2str(all_items["flag"])
+        self.train_data_txt = pathUtil.path_format_join(label_txt_dir, all_items["train_data_txt"])
+        self.train_data_yolo_txt = pathUtil.path_format_join(all_items["dir"], all_items["train_data_yolo_txt"])
+        self.train_data_csv = pathUtil.path_format_join(label_csv_dir, all_items["train_data_csv"])
+        self.val_data_txt = pathUtil.path_format_join(label_txt_dir, all_items["val_data_txt"])
+        self.val_data_yolo_txt = pathUtil.path_format_join(all_items["dir"], all_items["val_data_yolo_txt"])
+        self.val_data_csv = pathUtil.path_format_join(label_csv_dir, all_items["val_data_csv"])
+        self.class_index_txt = pathUtil.path_format_join(all_items["dir"], all_items["class_index_txt"])
+
+    def split_train_val(self):
+        """
+        切分训练集和测试集
+        首先根据文件名进行分组,然后确保同一个图片中不会出现相同的类别。此时遍历所有图片,一旦出现新的分类就加入字典中,并动态的将图片分到不同的类别下,进而保障了各个类别下图片分布的相对均匀
+
+        :return:
+        """
+        pathUtil.mkdir(self.train_dir)
+        pathUtil.mkdir(self.val_dir)
+        total_df = pd.read_csv(self.total_data_csv, engine='python')  # engine='python' avoids read errors on non-ASCII csv paths
+        # total_df: filename width ... xmax ymax
+        # dict of all classes seen so far, mapping class name -> assigned file stems
+        class_dict = dict()
+        # e.g. class_dict: {'E3-1': ['...-hd']}
+        # group the rows by image file name and iterate
+        # key: '...-hd.jpg'  values: filename width ... xmin ymin xmax ymax
+        groups = total_df.groupby('filename')
+        for key, values in tqdm(groups, total=len(groups), ncols=80):  # progress bar width = 80
+            flag = True
+            key = key.split('.jpg')[0]  # strip the .jpg extension
+            # drop duplicate classes inside this image (each class at most once per image);
+            # reassigning instead of inplace=True avoids pandas' SettingWithCopyWarning on a groupby slice
+            values = values.drop_duplicates(subset=['class_name'], keep='first')
+            # all already-known classes present in this image
+            class_name = []
+            # e.g. class_name: ['E3-1']
+            # iterate over every annotation row of this image
+            for k, v in values.iterrows():
+                # this class has not been seen before
+                if class_dict.get(v['class_name']) is None:
+                    path_list = [key]
+                    class_dict[v['class_name']] = path_list
+                    # BUG fix: flag prevents the image from also being assigned
+                    # to a second, already-known class below
+                    flag = False
+                    break
+                else:
+                    class_name.append(v['class_name'])
+            # exactly one known class: the image follows that class
+            if len(class_name) == 1 and flag:
+                class_dict[class_name[0]].append(key)
+            # several known classes: assign to the class holding the fewest images so far
+            elif len(class_name) > 1 and flag:
+                list_len = math.inf
+                min_class_key = None
+                for class_k in class_name:
+                    if list_len > len(class_dict[class_k]):
+                        list_len = len(class_dict[class_k])
+                        min_class_key = class_k
+                class_dict[min_class_key].append(key)
+            else:
+                continue
+        train_data, val_data = [], []
+        train_distribute = {}
+        val_distribute = {}
+        for key in class_dict:
+            clas = class_dict.get(key)
+            try:
+                split_ratio = float(self.split_ratio)
+            except ValueError:
+                print('conf文件中的split_ratio字段无法转换成数字,请检查')
+                sys.exit(1)
+
+            train_d, val_d = train_test_split(clas, random_state=2020, test_size=split_ratio,
+                                              shuffle=True)  # a failure here usually means some class has only one annotated image and cannot be split; check the folders under class_img_dir
+            train_distribute[key] = len(train_d)
+            val_distribute[key] = len(val_d)
+            train_data.extend(train_d)
+            val_data.extend(val_d)
+
+        print('数据分布如下:')
+        print('train数据集:%s,val数据集:%s \n' % (train_distribute, val_distribute))
+
+        return train_data, val_data
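+
+    # Toy walk-through of the assignment rule in split_train_val (editor's
+    # illustration; the image and class names are hypothetical):
+    #   img1 {A}:   A is unseen              -> class_dict = {'A': ['img1']}
+    #   img2 {A}:   one known class          -> {'A': ['img1', 'img2']}
+    #   img3 {A,B}: B is unseen, B claims it -> {'A': [...], 'B': ['img3']}
+    #   img4 {A,B}: both known, B has fewer  -> img4 joins 'B'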
+
+    # def total_copy_TrainVal(self, path_data, total_path, train_val_path):
+    #     """
+    #     拷贝jpg、xml文件到相应的train/val数据集文件夹下
+    #     :param path_data: 被划分的train/val数据集合
+    #     :param total_path: 总数据文件夹路径
+    #     :param train_val_path: 划分的train/val数据文件夹路径
+    #     :return:
+    #     """
+    #     for data in tqdm(path_data, total=len(path_data), ncols=80):
+    #         xml_path = pathUtil.path_format_join(total_path, str(data+'.xml'))
+    #         jpg_path = pathUtil.path_format_join(total_path, str(data+'.jpg'))
+    #         xml_new_path = pathUtil.path_format_join(train_val_path, str(str(xml_path).split('/')[-1]))
+    #         jpg_new_path = pathUtil.path_format_join(train_val_path, str(str(jpg_path).split('/')[-1]))
+    #         try:
+    #             shutil.copyfile(xml_path, xml_new_path)  # 复制文件到新的文件夹
+    #             shutil.copyfile(jpg_path, jpg_new_path)
+    #         except Exception as e:
+    #             print(e)
+    #             continue
+
+    def copy_yoloTxt_TrainVal(self, path_data, total_path, train_val_path, yolo_path):
+        """
+        拷贝jpg、xml文件到相应的train/val数据集文件夹下
+        :param path_data: 被划分的train/val数据集合
+        :param total_path: 总数据文件夹路径
+        :param train_val_path: 划分的train/val数据文件夹路径
+        :return:
+        """
+        yolo_image_path = pathUtil.path_format_join(yolo_path, 'images')
+        pathUtil.mkdir(yolo_image_path)
+        for data in tqdm(path_data, total=len(path_data), ncols=80):
+            xml_path = pathUtil.path_format_join(total_path, str(data + '.xml'))
+            jpg_path = pathUtil.path_format_join(total_path, str(data + '.jpg'))
+            xml_new_path = pathUtil.path_format_join(train_val_path, str(str(xml_path).split('/')[-1]))
+            # jpg_new_path = pathUtil.path_format_join(train_val_path, str(str(jpg_path).split('/')[-1]))
+            jpg_yolo_path = pathUtil.path_format_join(yolo_image_path, str(str(jpg_path).split('/')[-1]))
+            try:
+                shutil.copyfile(xml_path, xml_new_path)
+                # shutil.copyfile(jpg_path, jpg_new_path)
+                shutil.copyfile(jpg_path, jpg_yolo_path)
+            except Exception as e:
+                print(e)
+                continue
+
+    # def xml2_txtCsv(self, data_txt, data_csv, data_dir):
+    #     """
+    #     保存train/val数据信息到txt和csv文件中
+    #     :param data_txt: txt文件
+    #     :param data_csv: csv文件
+    #     :return:
+    #     """
+    #     xml_ = profileUtil.xmlUtil()
+    #     xml_.xml_to_txt(data_txt, data_dir, self.class_index_json)
+    #     xml_.xml_to_csv(data_csv, data_dir)
+
+    def xml2_yolotxtCsv(self, data_txt, data_csv, data_dir):
+        """
+        保存train/val数据信息到txt和csv文件中
+        :param data_txt: txt文件
+        :param data_csv: csv文件
+        :return:
+        """
+        xml_ = profileUtil.xmlUtil()
+        xml_.xml_to_yolo_txt(data_txt, data_dir, self.class_index_json)
+        # xml_.xml_to_csv(data_csv, data_dir)
+
+    def yaml_write(self):
+        """
+        train: /data/humaocheng/sunwin_project/data/kuye/yolo_txt/train_data
+        val: /data/humaocheng/sunwin_project/data/kuye/yolo_txt/val_data
+        # number of classes
+        nc: 6
+        # class names
+        names : ['c2', 'a1', 'b2', 'b1', 'c1', 'a2']
+
+
+        写入yolo_txt/xxx.yaml文件
+        :return:
+        """
+        yaml_train = 'train: ' + self.train_data_yolo_txt
+        yaml_val = 'val: ' + self.val_data_yolo_txt
+        with open(self.class_index_txt, 'r', encoding='utf8') as f:
+            names = [name.strip() for name in f.readlines()]
+        yaml_nc = 'nc: ' + str(len(names))
+        yaml_names = 'names: ' + str(names)
+
+        yaml = pathUtil.path_format_join('/'.join(self.train_data_yolo_txt.split('/')[0:-1]),
+                                         self.project_name + '.yaml')
+        with open(yaml, 'w', encoding='utf8') as f:  # the with-block closes the file, so no explicit close() is needed
+            f.write(yaml_train)
+            f.write('\n')
+            f.write(yaml_val)
+            f.write('\n')
+            f.write(yaml_nc)
+            f.write('\n')
+            f.write(yaml_names)
+
+    def main(self):
+        # pathUtil.mkdir_new(self.train_dir)
+        # pathUtil.mkdir_new(self.val_dir)
+        # pathUtil.mkdir_new(self.train_data_yolo_txt)
+        # pathUtil.mkdir_new(self.val_data_yolo_txt)
+        train_data, val_data = self.split_train_val()
+        self.copy_yoloTxt_TrainVal(train_data, self.data_dir, self.train_dir, self.train_data_yolo_txt)
+        self.copy_yoloTxt_TrainVal(val_data, self.data_dir, self.val_dir, self.val_data_yolo_txt)
+        self.yaml_write()
+        self.xml2_yolotxtCsv(self.train_data_yolo_txt, self.train_data_csv, self.train_dir)
+        self.xml2_yolotxtCsv(self.val_data_yolo_txt, self.val_data_csv, self.val_dir)
+        print('train_data val_data 分割完成')
+        print('\n-----------------------------step6完成-----------------------------\n')
+
+
+if __name__ == "__main__":
+    splitTrainVal().main()

+ 75 - 0
code/data_manage/gen_data/utils/coorUtil.py

@@ -0,0 +1,75 @@
+"""
+# File       : coorUtil.py
+# Time       :21.6.8 17:28
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Check whether an annotation box exceeds the bounds of its background image
+"""
+
+import cv2
+
+def check_coor(image_path, row):
+    """
+    Validate one annotation row against the actual image.
+    :return: (error, {image_path: [messages]})
+    """
+    error_dic = {}
+    width, height, xmin, ymin, xmax, ymax = row["width"], row["height"], row["xmin"], row["ymin"], row["xmax"], row["ymax"]
+    img = cv2.imread(image_path)
+    if img is None:
+        error_dic[image_path] = ['Could not read image']
+        return True, error_dic
+
+    org_height, org_width = img.shape[:2]
+
+    checks = [
+        (org_width != width, 'Width mismatch for image: ' + str(width) + '!=' + str(org_width)),
+        (org_height != height, 'Height mismatch for image: ' + str(height) + '!=' + str(org_height)),
+        (xmin > org_width, 'XMIN > org_width for file'),
+        (xmin < 0, 'XMIN < 0 for file'),
+        (xmax > org_width, 'XMAX > org_width for file'),
+        (ymin > org_height, 'YMIN > org_height for file'),
+        (ymin < 0, 'YMIN < 0 for file'),
+        (ymax > org_height, 'YMAX > org_height for file'),
+        (xmin >= xmax, 'xmin >= xmax for file'),
+        (ymin >= ymax, 'ymin >= ymax for file'),
+    ]
+    # collect every failed check; the old per-check dict assignment overwrote
+    # earlier entries, so only the last problem per image was ever reported
+    messages = [msg for failed, msg in checks if failed]
+    if messages:
+        error_dic[image_path] = messages
+    return bool(messages), error_dic
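+
+# Usage sketch (editor's illustration): `row` normally comes from the parsed
+# xml/csv dataframe; the literal values below are hypothetical.
+#   row = {"width": 1920, "height": 1080, "xmin": 10, "ymin": 20, "xmax": 5, "ymax": 40}
+#   error, error_dic = check_coor('some/image.jpg', row)
+#   # error is True and error_dic['some/image.jpg'] contains 'xmin >= xmax for file'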

+ 99 - 0
code/data_manage/gen_data/utils/fileUtil.py

@@ -0,0 +1,99 @@
+"""
+# File       : fileUtil.py
+# Time       :21.5.25 18:40
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: File filtering utilities
+"""
+import json
+import time
+import chardet
+import glob
+import pandas as pd
+from gen_data.utils import pathUtil
+
+class ReadFile:
+    """
+    定义读取文件类型
+    label_csv: label_csv下data数据,返回dataframe ['filename', 'filepath', 'width', 'height', 'depth', 'class_name','class_id', 'xmin', 'ymin', 'xmax', 'ymax']
+    label_txt: label_txt下data数据,返回dataframe ['filename', 'class_id', 'xmin', 'ymin', 'xmax', 'ymax']
+    class_json: class_index下label数据,返回class_dict字典 (key: class_name, value: class_id)
+    """
+
+    def read_json_label(self, class_index_json):
+        with open(class_index_json, "r") as f:
+            class_dict = json.load(f)
+        return class_dict
+
+
+    def read_txt_data(self, txt_dir, total_data_txt):
+        txt_list = []
+        with open(pathUtil.path_format_join(txt_dir, total_data_txt), "r") as f:
+            for line in f.readlines():
+                line = line.strip('\n').split(";")
+                filename = line[0].split("/")[1]
+                bbox = line[1:]
+                for member in bbox:
+                    member = member.split(",")
+                    class_id = member[-1]
+                    x_min,y_min,x_max,y_max = member[0],member[1],member[2],member[3]
+                    value = (filename, class_id, x_min, y_min, x_max, y_max)
+                    txt_list.append(value)
+
+        column_name = ["filename", "class_id", "xmin", "ymin", "xmax", "ymax"]
+        txt_df = pd.DataFrame(txt_list, columns=column_name)
+        return txt_df
+
+
+    def read_csv_data(self, csv_dir, total_data_csv):
+        """
+        读取cvs文件数据
+        :param csv_dir:csv文件夹路径
+        :param total_data_csv:csv文件
+        :return:
+        """
+        csv_df = pd.read_csv(pathUtil.path_format_join(csv_dir, total_data_csv), encoding='utf8')
+        return csv_df
+
+
+
+
+
+def extension_filter(base, extension_str):
+    """
+    Collect files with the given extensions in a directory (recursively)
+    and return their absolute paths.
+
+    :param base: directory to search
+    :param extension_str: comma-separated extension list from the conf file
+    :return: list of matching file paths
+    """
+    extension = extension_str.split(',')
+    fullname_list = []
+    for ex in extension:
+        ex = ex.strip() if ex.strip().startswith('.') else '.' + ex.strip()  # ensure a leading dot
+        ex_list = glob.glob(base + '/**/*' + ex, recursive=True)
+        fullname_list.extend(ex_list)
+    return fullname_list
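+
+# For example (editor's sketch; the folder is hypothetical):
+#   extension_filter('/data/videos', 'mp4, .avi')
+# globs /data/videos/**/*.mp4 and /data/videos/**/*.avi recursively and
+# returns the combined list of matches.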
+
+
+def detectCode(path):
+    """
+    Detect the text encoding of a file.
+    :param path:
+    :return: encoding name
+    """
+    with open(path, 'rb') as file:
+        data = file.read(1000)
+        dicts = chardet.detect(data)
+    return dicts["encoding"]
+
+def writelog(data):
+    now = time.strftime('%Y-%m-%d %H:%M', time.localtime(time.time()))
+    with open('../error_log.txt', 'a') as f:  # append mode: the file is created if missing and never truncated
+        f.write("===============================%s======================================\n" % (now))
+        f.writelines(str(data))
+        f.write("\n")
+

+ 32 - 0
code/data_manage/gen_data/utils/pathUtil.py

@@ -0,0 +1,32 @@
+"""
+# File       : pathUtil.py
+# Time       :21.5.25 18:13
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Path manipulation utilities
+"""
+import os
+import shutil
+from gen_data.utils import strUtil
+
+def mkdir(new_folder):
+    if not os.path.exists(new_folder):
+        os.makedirs(new_folder)
+
+def mkdir_new(new_folder):
+    if os.path.exists(new_folder):
+        shutil.rmtree(new_folder, ignore_errors=True)
+    os.makedirs(new_folder)
+
+def path_format(path_str):
+    path = strUtil.profile2str(path_str.replace('\\', '/'))  # normalise separators to '/'
+    if str(path).endswith('/'):
+        return path[0:-1]  # BUG fix: '/'.join(path[0:-1]) inserted a '/' between every character
+    else:
+        return path
+
+def path_format_join(path_str1, path_str2):
+    return os.path.join(path_format(path_str1), path_format(path_str2)).replace('\\', '/')  # join the two normalised paths
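+
+# Behaviour sketch (editor's illustration):
+#   path_format('C:\\data\\imgs\\')        -> 'C:/data/imgs'
+#   path_format_join('C:\\data', 'imgs/')  -> 'C:/data/imgs'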
+
+

+ 353 - 0
code/data_manage/gen_data/utils/profileUtil.py

@@ -0,0 +1,353 @@
+"""
+# File       : profileUtil.py
+# Time       :21.5.25 16:16
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Configuration file utilities
+"""
+
+import glob
+import os
+from xml.dom import minidom
+import numpy as np
+import chardet
+import pandas as pd
+import xml.etree.ElementTree as ET
+from configparser import ConfigParser
+from gen_data.utils import fileUtil, pathUtil
+import matplotlib.pyplot as plt
+
+
+class ConfigUtil:
+    """
+    conf文件工具类
+    """
+    def __init__(self):
+        self.ConfPath = r"./config/data_manage.conf"
+        self.video2ImageSection = "vidoe2image"
+        self.data_manageSection = "data_manage"
+        self.database = "database"
+        self.conf = ConfigParser()
+
+    def exists_config(self, conf_path):
+        """
+        Check that the configuration file exists and ends with .conf.
+
+        :param conf_path: path of the conf file
+        :return: True or False
+        """
+        if os.path.exists(conf_path):
+            if conf_path.endswith('.conf'):
+                return True
+            else:
+                print("该文件不是以.conf结尾请检查文件名: %s" % (conf_path))
+                return False
+        else:
+            print("该文件路径不存在,请检查文件路径:   %s" % (conf_path))
+            return False
+
+
+    def get_config_sections(self, conf_path):
+        """
+        Wrap config.sections(), additionally checking that the conf file
+        exists and is not empty.
+        config.sections() lists every section of the conf file.
+        :param conf_path: path of the conf file
+        :return: list of section names
+        """
+        if self.exists_config(conf_path):
+            encod = self.detectCode(conf_path)
+            self.conf.read(conf_path, encoding=encod)
+            conf_sections = self.conf.sections()
+            if len(conf_sections) > 0:
+                return conf_sections
+            print("该conf文件内容为空: %s" % (conf_path))
+
+        return []
+
+    def get_config_items(self, conf_path, section):
+        """
+        Wrap config.items(), additionally checking that the conf file exists
+        and is not empty, and returning a dict instead of a list of pairs.
+        config.items(section) lists every key/value pair of one section.
+        :param conf_path: path of the conf file
+        :param section: section inside the conf file
+        :return: dict
+        """
+        if self.exists_config(conf_path):
+            encod = fileUtil.detectCode(conf_path)
+            self.conf.read(conf_path, encoding=encod)
+            try:
+                conf_items = self.conf.items(section)
+                conf_dict = {key: val for key, val in conf_items}
+                return conf_dict
+            except Exception:
+                print("该conf文件中[%s]分区内容为空:    %s" % (section, conf_path))
+                return {}
+
+        return {}
+
+    def detectCode(self, path):
+        """
+        Detect the text encoding of a file.
+        :param path:
+        :return: encoding name
+        """
+        with open(path, 'rb') as file:
+            data = file.read(1000)  # sniffing the first 1000 bytes is enough for chardet
+            dicts = chardet.detect(data)
+        return dicts["encoding"]
+
+
+class xmlUtil:
+    """
+    xml文件工具类
+    """
+
+    def xml_parse(self, data_dir, find='filename'):
+        """
+        Parse every VOC xml file under data_dir.
+        :param data_dir: folder holding the xml files
+        :param find: tag that holds the image name ('filename' or 'path')
+        :return: dataframe (filename, width, height, depth, class_name, xmin, ymin, xmax, ymax)
+        """
+        error_xml_list = []
+        xml_list = []
+        for xml_file in glob.glob(data_dir + '/*.xml'):
+            xml_file_path = pathUtil.path_format(xml_file)
+            try:
+                tree = ET.parse(xml_file_path)
+            except ET.ParseError:
+                error_xml_list.append(xml_file_path)
+                continue
+            root = tree.getroot()
+            filename = root.find(find).text
+            width, height, depth = int(root.find('size')[0].text), int(root.find('size')[1].text), int(
+                root.find('size')[2].text)
+            for member in root.findall('object'):
+                class_name = member[0].text.upper()
+                x_min, y_min, x_max, y_max = int(member[4][0].text), int(member[4][1].text), int(
+                    member[4][2].text), int(member[4][3].text)
+                value = (filename, width, height, depth, class_name, x_min, y_min, x_max, y_max)
+                xml_list.append(value)
+
+        column_name = ['filename', 'width', 'height', 'depth', 'class_name', 'xmin', 'ymin', 'xmax', 'ymax']
+        xml_df = pd.DataFrame(xml_list, columns=column_name)
+        if error_xml_list:
+            fileUtil.writelog(error_xml_list)
+            print('解析错误的xml文件有:%s' % (error_xml_list))
+        return xml_df
+
+    def xml_to_csv(self, total_data_csv, path, data_distribution=None):
+        """
+        Convert the xml annotations to a csv file; optionally also save a
+        pie chart of the class distribution.
+
+        :return:
+        """
+        xml_df = self.xml_parse(path)
+        if data_distribution is None:
+            xml_df.to_csv(total_data_csv, index=False)
+            print("%s 写入 %d 行" % (total_data_csv, xml_df.shape[0]))
+        else:
+            df_groupby = xml_df.groupby('class_name')
+            all_class_dict = {}
+            for k, v in df_groupby:
+                all_class_dict[k] = len(v)
+            plt.figure(figsize=(40, 40))
+            label = [key for key in all_class_dict.keys()]
+            value = [val for val in all_class_dict.values()]
+            # value holds the counts, label the class names, autopct formats each wedge
+            plt.pie(value, labels=label, autopct=lambda x: self.my_label(x, value))
+            plt.savefig(data_distribution)
+            plt.show()
+            xml_df.to_csv(total_data_csv, index=False)
+            print("%s 写入 %d 行" % (total_data_csv, xml_df.shape[0]))
+
+    def xml_to_txt(self, total_data_txt, path, class_index_json):
+        """
+        # 将xml文件转换为txt文件,只需要知道图片的名字和'object'的坐标(四个点坐标)即可
+        将xml文件转换成txt文件
+        arrange label file from xml files, which is formatted as txt. For Example:
+
+        image_full_path [space] [x_min, y_min, x_max, y_max, class_index [space]],
+
+        Like:
+        /data/object_detection/keras-yolo3-master/data/label_data/K2_112647_540.jpg 456,1,516,104,4 662,1,708,102,4 457,229,519,403,4 664,231,711,397,4 852,227,917,401,4 1038,223,1121,396,4 1199,204,1280,417,9
+
+
+        @param data_dir: the folder is to save images and annotations
+        @param txt_dir: the path for saving annotations
+        @param class_index_json: this param is dictionary and key is class name, value is class index
+        # class_index_json这个参数是字典,键是类的名字,键值是类索引
+        @return:
+        """
+        label_content_list = []
+        class_dict = fileUtil.ReadFile().read_json_label(class_index_json)
+
+        xml_df = self.xml_parse(path)
+        group_df = xml_df.groupby('filename')
+        for k, values in group_df:
+            value = ''
+            for id, v in values.iterrows():  # iterate over the annotation rows of this image
+                class_name = str(class_dict[v['class_name']])
+                ymax, xmax, ymin, xmin = str(v['ymax']), str(v['xmax']), str(v['ymin']), str(v['xmin'])
+                single = ','.join([xmin, ymin, xmax, ymax, class_name])
+                value ="{};{}".format(value, single)
+            single_obj_str = "{}{}".format(k, value)
+            label_content_list.append(single_obj_str)
+
+        # write total_label.txt
+        count = 0
+        with open(total_data_txt, "w") as label_txt:
+            for index, label_content in enumerate(label_content_list):
+                label_txt.write(label_content)
+                if index < len(label_content_list) - 1:
+                    label_txt.write("\n")
+                count += 1
+        print("%s 写入 %d 行 " % (total_data_txt, count))
+
+    def yolo5_normalized(self, df):
+        # yolo txt format: class_id, x_center, y_center, width, height (all relative to the image size)
+        x = abs(df['xmax'] - df['xmin'])
+        y = abs(df['ymax'] - df['ymin'])
+        x_center = (x / 2 + df['xmin']) / df['width']
+        y_center = (y / 2 + df['ymin']) / df['height']
+        width = x / df['width']
+        height = y / df['height']
+        return x_center, y_center, width, height
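+
+    # Worked example (editor's illustration): for a 1000x500 image with a box
+    # (xmin=100, ymin=50, xmax=300, ymax=150):
+    #   x = 200, y = 100
+    #   x_center = (100 + 100) / 1000 = 0.2
+    #   y_center = (50 + 50) / 500 = 0.2
+    #   width = 200 / 1000 = 0.2, height = 100 / 500 = 0.2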
+
+    def xml_to_yolo_txt(self, yolo_txt, path, class_index_json):
+        """
+        将xml文件转换成yolo|_txt文件
+        arrange label file from xml files, which is formatted as txt. For Example:
+
+        image_full_path [space] [x_min, y_min, x_max, y_max, class_index [space]],
+
+        Like:
+        /data/object_detection/keras-yolo3-master/data/label_data/K2_112647_540.jpg 456,1,516,104,4 662,1,708,102,4 457,229,519,403,4 664,231,711,397,4 852,227,917,401,4 1038,223,1121,396,4 1199,204,1280,417,9
+
+
+        @param data_dir: the folder is to save images and annotations
+        @param txt_dir: the path for saving annotations
+        @param class_index_json: this param is dictionary and key is class name, value is class index
+        @return:
+        """
+        class_dict = fileUtil.ReadFile().read_json_label(class_index_json)
+
+        xml_df = self.xml_parse(path)
+        group_df = xml_df.groupby('filename')
+        count = 0
+        for k, values in group_df:
+            label_content_list = []
+            for id, v in values.iterrows():
+                class_name = str(class_dict[v['class_name']])
+                x_center, y_center, width, height = self.yolo5_normalized(v)
+                single = ' '.join([class_name, str(x_center), str(y_center), str(width), str(height)])
+                label_content_list.extend([single])
+
+            file_ = '.'.join(str(k).split('.')[0:-1])
+            file_name = pathUtil.path_format(file_ + '.txt')
+            yolo_txt_path = pathUtil.path_format_join(yolo_txt, 'labels')
+            pathUtil.mkdir(yolo_txt_path)
+            txt_path = pathUtil.path_format_join(yolo_txt_path, file_name)
+
+            with open(txt_path, "w+") as label_txt:
+                for index, label_content in enumerate(label_content_list):
+                    label_txt.write(label_content)
+                    if index < len(label_content_list) - 1:
+                        label_txt.write("\n")
+            count += 1
+        print("%s 写入 %d 个txt文件 "%(yolo_txt, count))
+    @staticmethod
+    def write_xml(img, df_t, whd_list, total_data_dir):
+        """
+        生成xml文件,写入数据后保存
+
+        :param img:添加过特征图像后新图片的路径
+        :param df_t:该图片中特征图像的坐标信息
+        :param whd_list:新图片的长宽信息
+        :param total_data_dir:最终保存xml数据的路径
+        :return:
+        """
+        filename = img.split('/')[-1]
+
+        # 1. create the DOM document object
+        doc = minidom.Document()
+
+        # 2. create the root node and attach it to the document
+        root_node = doc.createElement("annotation")
+        doc.appendChild(root_node)
+
+        # 3. create child nodes (each holding a text node) and attach them to the root
+        folder_node = doc.createElement("folder")
+        folder_value = doc.createTextNode('ZS')
+        folder_node.appendChild(folder_value)
+        root_node.appendChild(folder_node)
+
+        filename_node = doc.createElement("filename")
+        filename_value = doc.createTextNode(filename)
+        filename_node.appendChild(filename_value)
+        root_node.appendChild(filename_node)
+
+        path_node = doc.createElement("path")
+        path_value = doc.createTextNode(img)
+        path_node.appendChild(path_value)
+        root_node.appendChild(path_node)
+
+        source_node = doc.createElement("source")
+        database_node = doc.createElement("database")
+        database_node.appendChild(doc.createTextNode("Unknown"))
+        source_node.appendChild(database_node)
+        root_node.appendChild(source_node)
+
+        size_node = doc.createElement("size")
+        for item, value in zip(["width", "height", "depth"], whd_list):
+            elem = doc.createElement(item)
+            elem.appendChild(doc.createTextNode(str(value)))
+            size_node.appendChild(elem)
+        root_node.appendChild(size_node)
+
+        seg_node = doc.createElement("segmented")
+        seg_node.appendChild(doc.createTextNode(str(0)))
+        root_node.appendChild(seg_node)
+
+        for _, df in df_t.iterrows():
+            obj_node = doc.createElement("object")
+            name_node = doc.createElement("name")
+            name_node.appendChild(doc.createTextNode(str(df['class'])))
+            obj_node.appendChild(name_node)
+
+            pose_node = doc.createElement("pose")
+            pose_node.appendChild(doc.createTextNode("Unspecified"))
+            obj_node.appendChild(pose_node)
+
+            trun_node = doc.createElement("truncated")
+            trun_node.appendChild(doc.createTextNode(str(0)))
+            obj_node.appendChild(trun_node)
+
+            trun_node = doc.createElement("difficult")
+            trun_node.appendChild(doc.createTextNode(str(0)))
+            obj_node.appendChild(trun_node)
+
+            bndbox_node = doc.createElement("bndbox")
+            for item, value in zip(["xmin", "ymin", "xmax", "ymax"], [df['xmin'], df['ymin'], df['xmax'], df['ymax']]):
+                elem = doc.createElement(item)
+                elem.appendChild(doc.createTextNode(str(value)))
+                bndbox_node.appendChild(elem)
+            obj_node.appendChild(bndbox_node)
+            root_node.appendChild(obj_node)
+
+        xml_file = '.'.join(filename.split('.')[0:-1]) + '.xml'  # strip only the final extension, so dotted names survive
+        with open(pathUtil.path_format_join(total_data_dir, xml_file), "w", encoding="utf-8") as f:
+            # writexml(target file, root indent, child indent, newline, encoding)
+            doc.writexml(f, indent='', addindent='\t', newl='\n', encoding="utf-8")
+
+    @staticmethod
+    def my_label(pct, allvals):
+        absolute = int(pct / 100. * np.sum(allvals))
+        return "{:.1f}%\n({:d})".format(pct, absolute)

+ 39 - 0
code/data_manage/gen_data/utils/strUtil.py

@@ -0,0 +1,39 @@
+"""
+# File       : strUtil.py
+# Time       :21.5.25 18:50
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: String utilities
+"""
+import sys
+
+#
+# def str2list(strs, split_str = ','):
+#     strs = ''.join(x for x in strs if x.isprintable())
+#     str_list = strs.strip().split(split_str)
+def profile2str(st):
+    return st.replace('"', '').replace('\'', '')
+
+
+
+def is_num(n, key):
+    try:
+        num = int(n)
+        return num
+    except (TypeError, ValueError):
+        sys.exit('请确认%s字段的值是数字' % (key))
+
+
+def is_absnum(n, key):
+    num = is_num(n, key)
+    return abs(num)
+
+def num_str(st):
+    try:
+        num = int(st)
+        return num
+    except (TypeError, ValueError):
+        return st
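+
+# Behaviour sketch (editor's illustration):
+#   profile2str('"./data"') -> './data'
+#   is_absnum('-5', 'frame_num') -> 5
+#   num_str('12') -> 12, num_str('abc') -> 'abc'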
+
+

+ 151 - 0
code/data_manage/gen_data/video2image_step1.py

@@ -0,0 +1,151 @@
+"""
+# File       : video2image_step1.py
+# Time       :21.5.25 15:42
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Convert videos to images. step1_split videos into frames
+"""
+import sys
+import cv2
+import time
+import math
+import os
+from tqdm import tqdm
+from gen_data.utils import profileUtil, pathUtil, strUtil, fileUtil
+
+
+class Video2Image:
+    def __init__(self):
+        conf = profileUtil.ConfigUtil()
+        all_items = conf.get_config_items(conf.ConfPath, conf.video2ImageSection)
+        self.video_dir = pathUtil.path_format(all_items['video_dir'])  # video folder
+        self.save_dir = pathUtil.path_format(all_items['save_dir'])  # folder for the extracted images
+        self.vidoes_extension = pathUtil.path_format(all_items['vidoes_extension'])  # video extensions (key name kept as spelled in the conf file)
+        self.file_name = pathUtil.path_format(all_items['file_name'])  # image file prefix
+        self.all_image_num = strUtil.is_absnum(all_items['all_image_num'], 'all_image_num')  # total number of images to save
+        self.frame_num = strUtil.is_absnum(all_items['frame_num'], 'frame_num')  # frame interval
+
+    def get_frame_split(self, frames: int, all_image_num: int):
+        """
+        When frame_num is 0 (i.e. a total image count is given), compute the
+        sampling interval as the total frame count divided by the image count.
+
+        :param frames: total number of frames in the video
+        :param all_image_num: number of images to extract
+        :return: frame sampling interval
+        """
+        frame_split = math.ceil(frames / all_image_num)
+        if frame_split < 1:
+            sys.exit('%s//%s小于1,视频无法被正常切分。请检查all_image_num字段或者视频的总帧数是否太小' % (frames, all_image_num))
+
+        return frame_split
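+
+    # e.g. (editor's illustration): frames=1000, all_image_num=300 gives
+    # frame_split = ceil(1000 / 300) = 4, i.e. every 4th frame is kept,
+    # yielding roughly 250 images (slightly under the requested total)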
+
+
+    def gen_images(self, frame_split, vc, video_name):
+        """
+        Split a video into images saved in save_dir, named as:
+        <configured prefix>_<video name>_<timestamp>_<frame index>.jpg
+
+        :param frame_split: sampling interval in frames
+        :param vc: cv2.VideoCapture object
+        :param video_name: video file name
+        :return: number of images saved
+        """
+        time_str = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))
+        video = '.'.join(video_name.split(".")[0:-1])
+
+        if not vc.isOpened():
+            return 0  # BUG fix: `rval, _ = vc.read() if ... else False` crashed unpacking False
+        count = 0
+        total_frames = int(vc.get(cv2.CAP_PROP_FRAME_COUNT))
+        str_len = len(str(total_frames))
+        for c in tqdm(range(total_frames), total=total_frames, ncols=70):
+            # BUG fix: read every frame, keep every frame_split-th one; the old
+            # code only called read() on multiples, so it saved consecutive frames
+            rval, frame = vc.read()
+            if not rval:
+                break
+            if c % frame_split == 0:
+                image_name = self.file_name + '_' + video + '_' + time_str + '_' + str(c).zfill(str_len) + '.jpg'
+                image_save_name = self.save_dir + "/" + image_name
+                try:
+                    if os.path.exists(image_save_name):
+                        print('文件已存在,正在覆盖保存')
+                    cv2.imwrite(image_save_name, frame)
+                    count += 1
+                except cv2.error:
+                    continue
+        print('%s文件总共保存了%d张图片' % (video_name, count))
+        return count
+
+    def gen_images_2(self, frame_split, vc, video_name):
+        """
+        Variant for videos whose total frame count cv cannot read.
+        Images are saved under the same naming rule as gen_images().
+
+        :param frame_split: sampling interval in frames
+        :param vc: cv2.VideoCapture object
+        :param video_name: video file name
+        :return: number of images saved
+        """
+        time_str = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))
+        video = '.'.join(video_name.split('.')[0:-1])
+
+        frame_count, count = 0, 0
+        with tqdm(total=5000) as pbar:  # total unknown here, so an arbitrary bar length is used
+            while True:
+                ret, frame = vc.read()
+                if not ret:
+                    break
+                frame_count = frame_count + 1
+                if frame_count % frame_split == 0:
+                    image_name = self.file_name + '_' + video + '_' + time_str + '_' + str(frame_count) + '.jpg'
+                    image_save_name = self.save_dir + "/" + image_name
+                    pbar.update(1)
+                    pbar.set_description("切分第 %d 张图片" % count)
+                    try:
+                        if os.path.exists(image_save_name):
+                            print('文件已存在,正在覆盖保存')
+                        cv2.imwrite(image_save_name, frame)
+                        count += 1
+                    except cv2.error:
+                        continue
+        print('%s文件总共保存了%d张图片' % (video_name, count))
+        return count
+
+    def video_split(self, video_path):
+        """
+        Read the frame count of a video, decide (from the conf) whether
+        all_image_num or frame_num drives the split, then hand off to
+        gen_images() / gen_images_2().
+        :param video_path: video path
+        :return: number of images saved
+        """
+        video_name = video_path.split("\\")[-1].split("/")[-1]
+        vc = cv2.VideoCapture(video_path)
+        frames = int(vc.get(cv2.CAP_PROP_FRAME_COUNT))
+        print("%s 视频总帧数 %d,裁剪图片中..." % (video_name, frames))
+        # use all_image_num unless a positive frame_num is configured
+        frame_split = self.get_frame_split(frames, self.all_image_num) if self.frame_num < 1 else self.frame_num
+        if frames < 0:
+            count = self.gen_images_2(frame_split, vc, video_name)
+        else:
+            count = self.gen_images(frame_split, vc, video_name)
+        return count
+
+    def video_process(self):
+        """
+        Iterate over every matching video file under the folder and feed it
+        to video_split().
+        :return:
+        """
+        videos_count = 0
+        images_count = 0
+        video_path_list = fileUtil.extension_filter(self.video_dir, self.vidoes_extension)
+        for video_path in video_path_list:
+            images_count += self.video_split(video_path)
+            videos_count += 1
+        print('成功读取了%d个视频, 截取图片%d张' % (videos_count, images_count))
+
+    def main(self):
+        pathUtil.mkdir_new(self.save_dir)
+        self.video_process()
+        print('\n-----------------------------step1完成-----------------------------\n')
+
+if __name__ == "__main__":
+    Video2Image().main()

+ 16 - 0
code/data_manage/main.py

@@ -0,0 +1,16 @@
+# This is a sample Python script.
+
+# Press Shift+F10 to execute it or replace it with your code.
+# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
+
+
+def print_hi(name):
+    # Use a breakpoint in the code line below to debug your script.
+    print(f'Hi, {name}')  # Press Ctrl+F8 to toggle the breakpoint.
+
+
+# Press the green button in the gutter to run the script.
+if __name__ == '__main__':
+    print_hi('PyCharm')
+
+# See PyCharm help at https://www.jetbrains.com/help/pycharm/

+ 57 - 0
code/data_manage/run_gen.py

@@ -0,0 +1,57 @@
+"""
+# File       : run_gen.py
+# Time       :21.6.7 17:36
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""
+import sys
+import shutil
+import os
+import os.path
+
+from gen_data import video2image_step1, check_data_step2, classFeatureImages_step3, gen_class_index_step4, genAnn_step5, \
+    splitTrainVal_step6, class_operate
+# from gen_data import video2image_step1, check_data_step2, classFeatureImages_step3, gen_class_index_step4, genAnn_step5, \
+#     splitTrainVal_step6, gen_pb_tfrecord_step7, class_operate
+from test_util.other2jpg import other2jpg   # convert other image formats to jpg (36 formats supported)
+# from database import insert_tables_data, delete_tables_data, update_tables_data, create_db  # database creation and data operations
+
+# from database.utils import delete_data_util, insert_data_util, connect_db_util, select_data_util  # helper utilities
+from gen_data.utils import profileUtil, pathUtil, fileUtil  # conf handling, path handling, file filtering utilities
+
+
+def main():
+    step2 = check_data_step2.checkFile()
+    project_path = step2.project_path
+    total_path = step2.file_names
+    step2.main()
+    classFeatureImages_step3.classFeatureImages().main()
+    gen_class_index_step4.genClassIndex().main()
+    genAnn_step5.ganAnn().main()
+    splitTrainVal_step6.splitTrainVal().main()
+
+    def delete_files(path, keep_names):
+        # remove every sub-directory of path whose name is not in keep_names;
+        # plain files (e.g. data_distribution.jpg, classes.txt) are never touched
+        for f in os.listdir(path):
+            filepath = os.path.join(path, f)
+            if f not in keep_names and os.path.isdir(filepath):
+                shutil.rmtree(filepath, True)
+
+    # clean-up step; the nested `if __name__ == "__main__"` guard that used to
+    # wrap this block was almost certainly unintended, so it now runs with main()
+    conf = profileUtil.ConfigUtil()
+    all_items = conf.get_config_items(conf.ConfPath, conf.data_manageSection)
+    path = pathUtil.path_format(all_items["dir"])  # was path_format_join(dir, dir), which only worked by accident for absolute paths
+    keep_dirs = ['yolo', 'total_data', 'redundant_data']  # folders to keep
+    delete_files(path, keep_dirs)
+
+
+if __name__ == "__main__":
+    main()

+ 177 - 0
code/data_manage/test_util/Qt5/checkfile/Form.py

@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+
+# Form implementation generated from reading ui file 'Form.ui'
+#
+# Created by: PyQt5 UI code generator 5.15.4
+#
+# WARNING: Any manual changes made to this file will be lost when pyuic5 is
+# run again.  Do not edit this file unless you know what you are doing.
+
+
+import os
+import sys
+
+from PyQt5.QtWidgets import QFileDialog, QMessageBox
+from PyQt5 import QtCore, QtWidgets
+from utils import check_data_step2, classFeatureImages_step3
+
+
+class Ui_Form(QtWidgets.QMainWindow):
+    def __init__(self):
+        super(Ui_Form, self).__init__()
+        self.setupUi(self)  # setupUi already calls retranslateUi
+        self.cwd = os.getcwd()
+
+    def setupUi(self, Form):
+        Form.setObjectName("Form")
+        Form.resize(650, 400)
+        self.widget = QtWidgets.QWidget(Form)
+        self.widget.setGeometry(QtCore.QRect(50, 12, 562, 330))
+        self.widget.setObjectName("widget")
+        self.verticalLayout_3 = QtWidgets.QVBoxLayout(self.widget)
+        self.verticalLayout_3.setContentsMargins(0, 0, 0, 0)
+        self.verticalLayout_3.setObjectName("verticalLayout_3")
+        self.verticalLayout_2 = QtWidgets.QVBoxLayout()
+        self.verticalLayout_2.setObjectName("verticalLayout_2")
+        self.horizontalLayout = QtWidgets.QHBoxLayout()
+        self.horizontalLayout.setObjectName("horizontalLayout")
+        self.label = QtWidgets.QLabel(self.widget)
+        self.label.setObjectName("label")
+        self.horizontalLayout.addWidget(self.label)
+        self.lineEdit = QtWidgets.QLineEdit(self.widget)
+        self.lineEdit.setObjectName("lineEdit")
+        self.horizontalLayout.addWidget(self.lineEdit)
+        self.pushButton = QtWidgets.QPushButton(self.widget)
+        self.pushButton.setObjectName("pushButton")
+        self.horizontalLayout.addWidget(self.pushButton)
+        self.verticalLayout_2.addLayout(self.horizontalLayout)
+        self.pushButton_2 = QtWidgets.QPushButton(self.widget)
+        self.pushButton_2.setObjectName("pushButton_2")
+        self.verticalLayout_2.addWidget(self.pushButton_2)
+        self.verticalLayout_3.addLayout(self.verticalLayout_2)
+        self.verticalLayout = QtWidgets.QVBoxLayout()
+        self.verticalLayout.setObjectName("verticalLayout")
+        self.label_2 = QtWidgets.QLabel(self.widget)
+        self.label_2.setObjectName("label_2")
+        self.verticalLayout.addWidget(self.label_2)
+        self.label_3 = QtWidgets.QLabel(self.widget)
+        self.label_3.setObjectName("label_3")
+        self.verticalLayout.addWidget(self.label_3)
+        self.label_4 = QtWidgets.QLabel(self.widget)
+        self.label_4.setObjectName("label_4")
+        self.verticalLayout.addWidget(self.label_4)
+        self.label_5 = QtWidgets.QLabel(self.widget)
+        self.label_5.setObjectName("label_5")
+        self.verticalLayout.addWidget(self.label_5)
+        self.label_6 = QtWidgets.QLabel(self.widget)
+        self.label_6.setObjectName("label_6")
+        self.verticalLayout.addWidget(self.label_6)
+        self.label_7 = QtWidgets.QLabel(self.widget)
+        self.label_7.setObjectName("label_7")
+        self.verticalLayout.addWidget(self.label_7)
+        self.label_9 = QtWidgets.QLabel(self.widget)
+        self.label_9.setObjectName("label_9")
+        self.verticalLayout.addWidget(self.label_9)
+        self.label_8 = QtWidgets.QLabel(self.widget)
+        self.label_8.setObjectName("label_8")
+        self.verticalLayout.addWidget(self.label_8)
+        self.label_10 = QtWidgets.QLabel(self.widget)
+        self.label_10.setObjectName("label_10")
+        self.verticalLayout.addWidget(self.label_10)
+        self.label_11 = QtWidgets.QLabel(self.widget)
+        self.label_11.setObjectName("label_11")
+        self.verticalLayout.addWidget(self.label_11)
+        self.label_12 = QtWidgets.QLabel(self.widget)
+        self.label_12.setObjectName("label_12")
+        self.verticalLayout.addWidget(self.label_12)
+        self.label_13 = QtWidgets.QLabel(self.widget)
+        self.label_13.setObjectName("label_13")
+        self.verticalLayout.addWidget(self.label_13)
+        self.label_14 = QtWidgets.QLabel(self.widget)
+        self.label_14.setObjectName("label_14")
+        self.verticalLayout.addWidget(self.label_14)
+        self.label_15 = QtWidgets.QLabel(self.widget)
+        self.label_15.setObjectName("label_15")
+        self.verticalLayout.addWidget(self.label_15)
+        self.label_16 = QtWidgets.QLabel(self.widget)
+        self.label_16.setObjectName("label_16")
+        self.verticalLayout.addWidget(self.label_16)
+        self.verticalLayout_3.addLayout(self.verticalLayout)
+
+        self.pushButton.clicked.connect(self.openfolder_image_path)
+        self.pushButton_2.clicked.connect(self.button_click)
+
+        self.retranslateUi(Form)
+        QtCore.QMetaObject.connectSlotsByName(Form)
+
+    def retranslateUi(self, Form):
+        _translate = QtCore.QCoreApplication.translate
+        Form.setWindowTitle(_translate("Form", "质检工具"))
+        self.label.setText(_translate("Form", "标注图片所在文件夹:"))
+        self.pushButton.setText(_translate("Form", "选择文件夹"))
+        self.pushButton_2.setText(_translate("Form", "开始质检"))
+        self.label_2.setText(_translate("Form", "使用说明:"))
+        self.label_3.setText(_translate("Form", "1.选择标注图片及xml文件所在文件夹"))
+        self.label_4.setText(_translate("Form", "2.点击开始质检按钮,即可开始质检"))
+        self.label_5.setText(_translate("Form", "3.质检完成后去查看标注图片目录下的两个文件夹"))
+        self.label_6.setText(_translate("Form", "3.1 错误的标注文件会存放在redundant_data目录下,包括了缺失文件或者标注框越界"))
+        self.label_7.setText(_translate("Form", "3.2 切分后的标注框图像会存放在class_img_dir目录下,每一个类别对应一个文件夹"))
+        self.label_9.setText(_translate("Form", "4.质检"))
+        self.label_8.setText(_translate("Form", "4.1 质检的时候需要观察每个类别目录文件夹下的图片与标注图示是否相同"))
+        self.label_10.setText(_translate("Form", "4.2 观察标注的图片是否过于模糊,无法用肉眼辨别的需要删除"))
+        self.label_11.setText(_translate("Form", "4.3 除非特殊说明,否则标注目标残缺大于50%的图片需要删除"))
+        self.label_12.setText(_translate("Form", "5.质检后的修改"))
+        self.label_13.setText(_translate("Form", "5.1 切分的标注图片命名规则:原图片名称_xmin_ymin_xmax_y_max.jpg(xmin,xmax...为标注框位置)"))
+        self.label_14.setText(_translate("Form", "5.2 如果需要修改xml文件,则需要通过上面的命名规则去定位xml文件位置"))
+        self.label_15.setText(_translate("Form", "5.3 例如:16258961220000-10304-1626535777-hd_4_112_199_446.jpg标注文件有问题,需要修改xml文件"))
+        self.label_16.setText(_translate("Form", "5.4 原图片文件及xml文件名为:16258961220000-10304-1626535777-hd.xxx"))
+
+    def openfolder_image_path(self):
+        """
+        Choose the folder holding the annotated images.
+        :return:
+        """
+        openfolder_path = QFileDialog.getExistingDirectory(self, '打开文件夹', self.cwd)
+        self.lineEdit.setText(openfolder_path)
+
+    def show_message_succes(self, len_df):
+        """
+        Message box shown on success.
+        :return:
+        """
+        success = "图片质检完成!总标注框数为:%s" % len_df
+        QMessageBox.about(self, "信息提示", success)
+
+    def show_message_folder(self):
+        """
+        Message box shown when no folder path was given.
+        :return:
+        """
+        QMessageBox.about(self, "空值提示", "没有获取到相应的文件夹路径!")
+
+    def button_click(self):
+        if self.lineEdit.text():
+            path = self.lineEdit.text()
+            self.check_images(path)
+        else:
+            self.show_message_folder()
+
+    def check_images(self, path):
+        check_data_step2.checkFile(path).main()
+        len_df = classFeatureImages_step3.classFeatureImages(path).main()
+        self.show_message_succes(len_df)
+
+
+if __name__ == "__main__":
+    app = QtWidgets.QApplication(sys.argv)
+    QMainWindow = QtWidgets.QMainWindow()
+    ui = Ui_Form()
+    ui.setupUi(QMainWindow)
+    QMainWindow.show()
+    sys.exit(app.exec_())

+ 44 - 0
code/data_manage/test_util/Qt5/checkfile/Form.spec

@@ -0,0 +1,44 @@
+# -*- mode: python ; coding: utf-8 -*-
+
+
+block_cipher = None
+
+
+a = Analysis(
+    ['Form.py'],
+    pathex=[],
+    binaries=[],
+    datas=[],
+    hiddenimports=[],
+    hookspath=[],
+    hooksconfig={},
+    runtime_hooks=[],
+    excludes=[],
+    win_no_prefer_redirects=False,
+    win_private_assemblies=False,
+    cipher=block_cipher,
+    noarchive=False,
+)
+pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
+
+exe = EXE(
+    pyz,
+    a.scripts,
+    a.binaries,
+    a.zipfiles,
+    a.datas,
+    [],
+    name='Form',
+    debug=False,
+    bootloader_ignore_signals=False,
+    strip=False,
+    upx=True,
+    upx_exclude=[],
+    runtime_tmpdir=None,
+    console=True,
+    disable_windowed_traceback=False,
+    argv_emulation=False,
+    target_arch=None,
+    codesign_identity=None,
+    entitlements_file=None,
+)

+ 172 - 0
code/data_manage/test_util/Qt5/checkfile/Form.ui

@@ -0,0 +1,172 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<ui version="4.0">
+ <class>Form</class>
+ <widget class="QWidget" name="Form">
+  <property name="geometry">
+   <rect>
+    <x>0</x>
+    <y>0</y>
+    <width>847</width>
+    <height>642</height>
+   </rect>
+  </property>
+  <property name="windowTitle">
+   <string>Form</string>
+  </property>
+  <widget class="QWidget" name="">
+   <property name="geometry">
+    <rect>
+     <x>50</x>
+     <y>12</y>
+     <width>562</width>
+     <height>330</height>
+    </rect>
+   </property>
+   <layout class="QVBoxLayout" name="verticalLayout_3">
+    <item>
+     <layout class="QVBoxLayout" name="verticalLayout_2">
+      <item>
+       <layout class="QHBoxLayout" name="horizontalLayout">
+        <item>
+         <widget class="QLabel" name="label">
+          <property name="text">
+           <string>标注图片所在文件夹:</string>
+          </property>
+         </widget>
+        </item>
+        <item>
+         <widget class="QLineEdit" name="lineEdit"/>
+        </item>
+        <item>
+         <widget class="QPushButton" name="pushButton">
+          <property name="text">
+           <string>选择文件夹</string>
+          </property>
+         </widget>
+        </item>
+       </layout>
+      </item>
+      <item>
+       <widget class="QPushButton" name="pushButton_2">
+        <property name="text">
+         <string>开始质检</string>
+        </property>
+       </widget>
+      </item>
+     </layout>
+    </item>
+    <item>
+     <layout class="QVBoxLayout" name="verticalLayout">
+      <item>
+       <widget class="QLabel" name="label_2">
+        <property name="text">
+         <string>使用说明:</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_3">
+        <property name="text">
+         <string>1.选择标注图片及xml文件所在文件夹</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_4">
+        <property name="text">
+         <string>2.点击开始质检按钮,即可开始质检</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_5">
+        <property name="text">
+         <string>3.质检完成后去查看标注图片目录下的两个文件夹</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_6">
+        <property name="text">
+         <string>3.1 错误的标注文件会存放在redundant_data目录下,包括了缺失文件或者标注框越界</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_7">
+        <property name="text">
+         <string>3.2 切分后的标注框图像会存放在class_img_dir目录下,每一个类别对应一个文件夹</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_9">
+        <property name="text">
+         <string>4.质检</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_8">
+        <property name="text">
+         <string>4.1 质检的时候需要观察每个类别目录文件夹下的图片与标注图示是否相同</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_10">
+        <property name="text">
+         <string>4.2 观察标注的图片是否过于模糊,无法用肉眼辨别的需要删除</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_11">
+        <property name="text">
+         <string>4.3 除非特殊说明,否则标注目标残缺大于50%的图片需要删除</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_12">
+        <property name="text">
+         <string>5.质检后的修改</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_13">
+        <property name="text">
+         <string>5.1 切分的标注图片命名规则:原图片名称_xmin_ymin_xmax_y_max.jpg(xmin,xmax...为标注框位置)</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_14">
+        <property name="text">
+         <string>5.2 如果需要修改xml文件,则需要通过上面的命名规则去定位xml文件位置</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_15">
+        <property name="text">
+         <string>5.3 例如:16258961220000-10304-1626535777-hd_4_112_199_446.jpg标注文件有问题,需要修改xml文件</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_16">
+        <property name="text">
+         <string>5.4 原图片文件及xml文件名为:16258961220000-10304-1626535777-hd.xxx</string>
+        </property>
+       </widget>
+      </item>
+     </layout>
+    </item>
+   </layout>
+  </widget>
+ </widget>
+ <resources/>
+ <connections/>
+</ui>

File diff suppressed because it is too large
+ 1 - 0
code/data_manage/test_util/Qt5/checkfile/error_log.txt


+ 8 - 0
code/data_manage/test_util/Qt5/checkfile/utils/__init__.py

@@ -0,0 +1,8 @@
+"""
+# File       : __init__.py
+# Time       :21.7.28 17:43
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""

+ 87 - 0
code/data_manage/test_util/Qt5/checkfile/utils/check_data_step2.py

@@ -0,0 +1,87 @@
+"""
+# File       : check_data_step2.py
+# Time       :21.5.29 12:26
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Re-inspect the files delivered by the annotation vendor, ensuring xml and jpg files pair one-to-one. step2_check for unlabeled images
+"""
+import glob
+import shutil
+import sys
+
+from utils import pathUtil, fileUtil
+
+class checkFile:
+    def __init__(self, path):
+        self.project_path = pathUtil.path_format(path)
+        self.redundant_dir = pathUtil.path_format_join(path, "redundant_data")  # where problem xml/jpg files are moved
+        self.file_names = self.project_path  # image folder
+
+    def diff_check(self, list1, list2, file_type):
+        """
+        Find files present in list1 but missing from list2, move them into
+        redundant_dir, and return their absolute paths.
+        :param list1: list of file stems
+        :param list2: list of file stems
+        :param file_type: file extension
+        :return: list of problem file paths
+        """
+        problem_list = []
+        diff_list = set(list1).difference(list2)
+        for diff in diff_list:
+            problem_file_name = diff + file_type
+            problem_file_path = pathUtil.path_format_join(self.file_names, problem_file_name)
+            move_file_path = pathUtil.path_format_join(self.redundant_dir, problem_file_name)
+            problem_list.append(problem_file_path)
+            shutil.move(problem_file_path, move_file_path)
+        if len(problem_list) > 0:
+            # snapshot the paths before appending the summary dict, otherwise
+            # the dict would hold a reference to the list it sits inside
+            paths = list(problem_list)
+            if file_type == '.xml':
+                problem_list.extend([{'.xml文件缺少对应的.jpg文件': paths}])
+                print('这些.xml文件缺少对应的.jpg文件:%s' % (paths))
+            else:
+                problem_list.extend([{'.jpg文件缺少对应的.xml文件': paths}])
+                print('这些.jpg文件缺少对应的.xml文件:%s' % (paths))
+        return problem_list
+
+
+    def check_file(self, xml_name_list, jpg_name_list):
+        """
+        筛选出xml和jpg无法一一对应的问题文件路径。并组成列表并返回
+        :param xml_name_list: xml文件的列表
+        :param jpg_name_list: jpg文件的列表
+        :return: 问题文件列表
+        """
+        diff1 = self.diff_check(xml_name_list, jpg_name_list, '.xml')
+        diff2 = self.diff_check(jpg_name_list, xml_name_list, '.jpg')
+        problem_list = diff1 + diff2
+        return problem_list
+
+    def main(self):
+
+        xml_name_list = [pathUtil.path_format(file).split('/')[-1].split('.xml')[0] for file in glob.glob(self.file_names + '/*.xml')]
+        jpg_name_list = [pathUtil.path_format(file).split('/')[-1].split('.jpg')[0] for file in glob.glob(self.file_names + '/*.jpg')]
+        if len(xml_name_list)+len(jpg_name_list) < 1:
+            print('没有找相应的数据,请检查 %s 路径数据'% self.file_names)
+            sys.exit(-1)
+        pathUtil.mkdir_new(self.redundant_dir)
+        problem_list = self.check_file(xml_name_list, jpg_name_list)
+        if problem_list:
+            fileUtil.writelog(problem_list)
+            print('问题文件的存放地址为:%s'%(self.redundant_dir))
+        else:
+            print('检验完毕,xml文件和jpg文件正常!')
+        print('\n-----------------------------step1完成-----------------------------\n')
+
+if __name__ == "__main__":
+    # checkFile requires the annotated-data directory as its only argument
+    checkFile(sys.argv[1]).main()
+
+
+
+
+
+
+
+
+
+
+

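The pairing check above reduces to two set differences over file base names. A minimal standalone sketch of the same idea (hypothetical `data_dir`, without the file moving):

    import glob
    import os

    def unpaired(data_dir):
        # base names of every .xml and .jpg file in data_dir
        xmls = {os.path.splitext(os.path.basename(p))[0] for p in glob.glob(os.path.join(data_dir, '*.xml'))}
        jpgs = {os.path.splitext(os.path.basename(p))[0] for p in glob.glob(os.path.join(data_dir, '*.jpg'))}
        # xml files missing a jpg, and jpg files missing an xml
        return sorted(xmls - jpgs), sorted(jpgs - xmls)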
+ 118 - 0
code/data_manage/test_util/Qt5/checkfile/utils/classFeatureImages_step3.py

@@ -0,0 +1,118 @@
+"""
+# File       : classFeatureImages_step3.py
+# Time       :21.5.28 14:35
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Extract the annotated feature regions from the images and sort them by class. step3 - crop the annotated feature images and classify them
+"""
+import math
+import cv2
+import shutil
+import time
+import numpy as np
+from tqdm import tqdm
+from utils import coorUtil, profileUtil, pathUtil, fileUtil
+import threading
+
+class classFeatureImages:
+
+    def __init__(self, path):
+        self.img_dir = pathUtil.path_format(path)  # image directory
+        self.redundant_dir = pathUtil.path_format_join(path, "redundant_data")
+        self.class_img_dir = pathUtil.path_format_join(path, "class_img_dir") # where the classified crops are saved
+
+
+    def cv_imread(self, filePath):
+        """
+        Read an image; works around cv2.imread() failing on non-ASCII paths.
+        :param filePath: file path
+        :return:
+        """
+        cv_img = cv2.imdecode(np.fromfile(filePath, dtype=np.uint8), -1)
+        return cv_img
+
+
+    @staticmethod
+    def splitdf(df, num):
+        """Split a DataFrame into num equal chunks plus the remainder."""
+        linenum = math.floor(len(df) / num)
+        pdlist = []
+        for i in range(num):
+            pdlist.append(df[i * linenum:(i + 1) * linenum])
+        # the remainder starts at num * linenum; starting at (num - 1) * linenum
+        # would process the rows of the last chunk twice
+        pdlist.append(df[num * linenum:])
+        return pdlist
+
+    def isexists_class(self, df, class_img_dir):
+        """
+        Read every image class from the ['class_name'] column and create a
+        folder for each class that does not exist yet.
+        :param df: DataFrame
+        :param class_img_dir: parent directory of the class folders
+        :return:
+        """
+        group_df = df.groupby('class_name')
+        for k, _ in group_df:
+            class_dir = pathUtil.path_format_join(class_img_dir, str(k))
+            pathUtil.mkdir(class_dir)
+
+
+    def gen_class_img_thread(self, csv_df, img_dir, class_img_dir):
+        self.isexists_class(csv_df, class_img_dir)
+        threads = []
+        df_list = self.splitdf(csv_df, 10)
+        for df in df_list:
+            t =threading.Thread(target=self._gen_class_img, args=(df, img_dir, class_img_dir))
+            threads.append(t)
+        for t in threads:
+            t.daemon = True
+            t.start()
+        for t in threads:
+            t.join()
+        time.sleep(10)
+
+        return len(csv_df)
+
+    def _gen_class_img(self, csv_df, img_dir, class_img_dir):
+        """
+        Crop every annotated feature region out of the images and save it into
+        the folder of its class.
+        :param img_dir: directory of the original images
+        :param class_img_dir: parent directory of the per-class folders; joined with class_name to build each save path
+        :return:
+        """
+        errors = []
+        for index, row in tqdm(csv_df.iterrows(),total=len(csv_df), ncols=60):
+            filename, class_name = row["filename"], row["class_name"]
+            image_path = pathUtil.path_format_join(img_dir, filename)
+            error, error_dic = coorUtil.check_coor(image_path, row)
+            if error:
+                old_file_path  = '.'.join(image_path.split('.')[0:-1])
+                new_file = '.'.join(image_path.split('\\')[-1].split('/')[-1].split('.')[0:-1])
+                new_file_path = pathUtil.path_format_join(self.redundant_dir,new_file)
+                try:
+                    shutil.move(old_file_path+'.jpg', new_file_path+'.jpg')
+                    shutil.move(old_file_path+'.xml', new_file_path+'.xml')
+                    errors.extend([error_dic])
+                except:
+                    # the xml/jpg pair may already have been moved by an earlier bad row
+                    pass
+                continue
+            xmin, ymin, xmax, ymax = row["xmin"], row["ymin"], row["xmax"], row["ymax"]
+            class_file = pathUtil.path_format_join(class_img_dir, "{}".format(row["class_name"]))
+            image = self.cv_imread(image_path)
+            cropimg = image[int(ymin):int(ymax), int(xmin):int(xmax)]
+            img_path = pathUtil.path_format_join(class_file, filename.split('.jpg')[0] + '_' + str(xmin) + '_' + str(ymin) + '_' + str(xmax) + '_' + str(ymax) + '.jpg')
+            cv2.imwrite(img_path, cropimg)
+        if errors:
+            print('标注图像有问题:', errors)
+            fileUtil.writelog(errors)
+
+    def main(self):
+        pathUtil.mkdir_new(self.class_img_dir)
+        csv_df = profileUtil.xmlUtil().xml_parse(self.img_dir)
+        len_df = self.gen_class_img_thread(csv_df, self.img_dir,self.class_img_dir)
+        print('\n-----------------------------step2完成-----------------------------\n')
+        return len_df
+
+#
+# if __name__ == '__main__':
+#     classFeatureImages().main()

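splitdf hand-rolls the chunking that feeds the worker threads. numpy offers the same split without the off-by-one risk; a sketch, assuming a pandas DataFrame:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({'a': range(23)})
    # array_split returns 10 chunks whose sizes differ by at most one row,
    # together covering every row exactly once
    chunks = np.array_split(df, 10)
    assert sum(len(c) for c in chunks) == len(df)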
+ 75 - 0
code/data_manage/test_util/Qt5/checkfile/utils/coorUtil.py

@@ -0,0 +1,75 @@
+"""
+# File       : coorUtil.py
+# Time       :21.6.8 17:28
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Check whether an annotation box exceeds the width/height of the source image
+"""
+
+import cv2
+
+def check_coor(image_path, row):
+    error_dic = {}
+    width, height, xmin, ymin, xmax, ymax = row["width"], row["height"], row["xmin"], row["ymin"], row["xmax"], row["ymax"]
+    img = cv2.imread(image_path)
+    error = False
+    if img is None:
+        error = True
+        code = str('Could not read image')
+        error_dic[image_path] = code
+        return error, error_dic
+
+    org_height, org_width = img.shape[:2]
+
+    if org_width != width:
+        error = True
+        code = str('Width mismatch for image: ' + str(width) + '!=' + str(org_width))
+        error_dic[image_path] = code
+
+    if org_height != height:
+        error = True
+        code = str('Height mismatch for image: ' + str(height) + '!=' + str(org_height))
+        error_dic[image_path] = code
+
+    if xmin > org_width:
+        error = True
+        code = str('XMIN > org_width for file')
+        error_dic[image_path] = code
+
+    if xmin < 0:
+        error = True
+        code = str('XMIN < 0 for file')
+        error_dic[image_path] = code
+
+    if xmax > org_width:
+        error = True
+        code = str('XMAX > org_width for file')
+        error_dic[image_path] = code
+
+    if ymin > org_height:
+        error = True
+        code = str('YMIN > org_height for file')
+        error_dic[image_path] = code
+
+    if ymin < 0:
+        error = True
+        code = str('YMIN < 0 for file')
+        error_dic[image_path] = code
+
+    if ymax > org_height:
+        error = True
+        code = str('YMAX > org_height for file')
+        error_dic[image_path] = code
+
+    if xmin >= xmax:
+        error = True
+        code = str('xmin >= xmax for file')
+        error_dic[image_path] = code
+
+    if ymin >= ymax:
+        error = True
+        code = str('ymin >= ymax for file')
+        error_dic[image_path] = code
+
+    return error, error_dic

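check_coor takes one parsed annotation row; a minimal sketch of a caller (the dict stands in for a DataFrame row, and 'img.jpg' is a hypothetical path):

    row = {"width": 640, "height": 480,
           "xmin": 10, "ymin": 20, "xmax": 200, "ymax": 220}
    error, error_dic = check_coor("img.jpg", row)
    if error:
        print(error_dic)

Note that every check writes to the same image_path key, so error_dic retains only the last failure for a given image.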
+ 17 - 0
code/data_manage/test_util/Qt5/checkfile/utils/fileUtil.py

@@ -0,0 +1,17 @@
+"""
+# File       : fileUtil.py
+# Time       :21.5.25 18:40
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: File filtering utilities
+"""
+import time
+
+def writelog(data):
+    now = time.strftime('%Y-%m-%d %H:%M', time.localtime(time.time()))
+    with open('../error_log.txt', 'a') as f:
+        f.write("===============================%s======================================\n"%(now))
+        f.writelines(str(data))
+        f.write("\n")
+

+ 31 - 0
code/data_manage/test_util/Qt5/checkfile/utils/pathUtil.py

@@ -0,0 +1,31 @@
+"""
+# File       : pathUtil.py
+# Time       :21.5.25 18:13
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Path manipulation utilities
+"""
+import os
+import shutil
+
+def mkdir(new_folder):
+    if not os.path.exists(new_folder):
+        os.makedirs(new_folder)
+
+def mkdir_new(new_folder):
+    if os.path.exists(new_folder):
+        shutil.rmtree(new_folder, ignore_errors=True)
+    os.makedirs(new_folder)
+
+def path_format(path_str):
+    path = path_str.replace('\\','/').replace('"', '').replace('\'', '')
+    if str(path).endswith('/'):
+        # drop the trailing slash; '/'.join(path[0:-1]) would instead interleave
+        # a '/' between every character of the string
+        return path[0:-1]
+    else:
+        return path
+
+def path_format_join(path_str1, path_str2):
+    return os.path.join(path_format(path_str1), path_format(path_str2)).replace('\\','/')
+
+

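For reference, the behaviour the two helpers are meant to have (expected results shown as comments):

    path_format('C:\\data\\imgs\\')       # -> 'C:/data/imgs'
    path_format('"C:/data/imgs"')         # -> 'C:/data/imgs'
    path_format_join('C:\\data', 'imgs')  # -> 'C:/data/imgs'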
+ 53 - 0
code/data_manage/test_util/Qt5/checkfile/utils/profileUtil.py

@@ -0,0 +1,53 @@
+"""
+# File       : profileUtil.py
+# Time       :21.5.25 16:16
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Configuration file parsing utilities
+"""
+
+import glob
+import pandas as pd
+import xml.etree.ElementTree as ET
+from utils import fileUtil, pathUtil
+
+
+class xmlUtil:
+    """
+    XML file utility class
+    """
+
+    def xml_parse(self, data_dir, find='filename'):
+        """
+        :param data_dir: directory containing the xml files
+        :param find: tag that holds the image file name
+        :return: dataframe (filename, width, height, depth, class_name, xmin, ymin, xmax, ymax)
+        """
+        error_xml_list = []
+        xml_list = []
+        for xml_file in glob.glob(data_dir + '/*.xml'):
+            xml_file_path = pathUtil.path_format(xml_file)
+            try:
+                tree = ET.parse(xml_file_path)
+            except ET.ParseError:
+                error_xml_list.append(xml_file_path)
+                continue
+            root = tree.getroot()
+            filename = root.find(find).text
+            width, height, depth = int(root.find('size')[0].text), int(root.find('size')[1].text), int(
+                root.find('size')[2].text)
+            for member in root.findall('object'):
+                class_name = member[0].text.upper()
+                x_min, y_min, x_max, y_max = int(member[4][0].text), int(member[4][1].text), int(
+                    member[4][2].text), int(member[4][3].text)
+                value = (filename, width, height, depth, class_name, x_min, y_min, x_max, y_max)
+                xml_list.append(value)
+
+
+        column_name = ['filename', 'width', 'height','depth' ,'class_name', 'xmin', 'ymin', 'xmax', 'ymax']
+        xml_df = pd.DataFrame(xml_list, columns=column_name)
+        if error_xml_list:
+            fileUtil.writelog(error_xml_list)
+            print('解析错误的xml文件有:%s'%(error_xml_list))
+        return xml_df

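A sketch of consuming the DataFrame that xml_parse returns, assuming a hypothetical directory 'annotations' of Pascal-VOC style xml files:

    df = xmlUtil().xml_parse('annotations')
    # one row per <object> element; group to count the boxes per class
    print(df.groupby('class_name').size())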
File diff suppressed because it is too large
+ 3 - 0
code/data_manage/test_util/Qt5/error_log.txt


+ 151 - 0
code/data_manage/test_util/Qt5/excel/ModifyTree.py

@@ -0,0 +1,151 @@
+# -*- coding: utf-8 -*-
+import sys
+import time
+from PyQt5.QtWidgets import *
+from PyQt5.QtGui import QIcon
+from PyQt5.QtCore import Qt
+from PyQt5.QtCore import QSize
+import pandas as pd
+import xlsxwriter
+from PIL import Image
+
+
+  
+class TreeWidget(QWidget):
+    """
+    Qt UI setup
+
+    """
+    def __init__(self):
+        super(TreeWidget, self).__init__()
+
+        self.read_excel_path = 'data/image.xlsx'
+        self.tree_list = []
+        self.df_excel = pd.read_excel(self.read_excel_path)
+        self.setWindowTitle('标注说明文档导出')
+        self.resize(900, 600)
+        
+        self.tree = QTreeWidget()  # create the tree widget
+        self.tree.setColumnCount(2)  # the widget has two columns
+        self.tree.setDropIndicatorShown(True)
+
+
+        self.tree.setColumnWidth(0, 200)
+        self.tree.setIconSize(QSize(100, 100))
+
+        self.tree.setSelectionMode(QAbstractItemView.ExtendedSelection)  # allow selecting multiple items
+        self.tree.setHeaderLabels(['特征类别', '示例图片'])  # header labels of the two columns
+
+        # root is created as a child of self.tree, so it becomes the root node
+        root = QTreeWidgetItem(self.tree)
+        root.setText(0, '特征标注')  # name of the root node
+
+
+
+        # add the child nodes under root
+        df_excel_group = self.group_by(self.df_excel, 'class')
+        for k, values in df_excel_group:
+            child = QTreeWidgetItem(root)
+            child.setText(0, k)
+            child.setExpanded(True)
+            for key, value in self.group_by(values, 'code'):
+                child2 = QTreeWidgetItem(child)
+                child2.setText(0, str(key))
+                child2.setCheckState(0, Qt.Unchecked)
+                child2.setExpanded(True)
+                for k, v in value.iterrows():
+                    child3 = QTreeWidgetItem(child2)
+                    child3.setText(0, v['code_num'])
+                    child3.setIcon(1, QIcon(v['img_path']))
+
+                    # child2.setIcon(0, QIcon(v['img_path']))
+
+        button=QPushButton("导出")
+        lay=QVBoxLayout()
+        lay.addWidget(button)
+        lay.addWidget(self.tree)
+
+        button.clicked.connect(self.write_excel)
+        button.clicked.connect(self.show_message)
+        
+        self.tree.itemChanged.connect(self.handleChanged)
+        
+        self.tree.addTopLevelItem(root)
+        self.setLayout(lay)  # make this layout the core frame of the window
+        
+    def handleChanged(self, item, column):
+        # When an item's check state changes, record it: checked items are added
+        # to the list so they can be exported later, unchecked items are removed
+        if item.checkState(column) == Qt.Checked:
+            self.tree_list.extend([item.text(column)])
+        if item.checkState(column) == Qt.Unchecked:
+            self.tree_list.remove(item.text(column))
+
+    def show_message(self):
+        """
+        Message box
+        :return:
+        """
+        QMessageBox.about(self, "消息提示","excel文件导出成功,程序退出")
+        app.quit()
+
+    def group_by(self, df, st):
+        """Group a DataFrame by the given column."""
+        return df.groupby(st)
+
+
+    def write_excel(self):
+        """
+        Write the selected rows out to an excel file
+        :return:
+        """
+        file_name = time.strftime('%Y%m%d%H%M%S', time.localtime(time.time()))+'.xlsx'
+        format = {
+            'bold': True,          # bold font
+            'align': 'center',     # horizontal alignment: centered
+            'valign': 'vcenter',   # vertical alignment: centered
+            'fg_color': 'C4C4C4',  # background fill color
+            'font_size': 14,       # font size
+        }
+
+        workbook = xlsxwriter.Workbook(file_name)
+        str_format = workbook.add_format(format)
+        worksheet = workbook.add_worksheet('sheet1')
+
+        worksheet.write(0, 0, '类别', str_format)
+        worksheet.write(0, 1, '代码', str_format)
+        worksheet.write(0, 2, '代码状态', str_format)
+        worksheet.write(0, 3, '注释', str_format)
+        worksheet.write(0, 4, '示例图片', str_format)
+
+        worksheet.set_column_pixels(1, 10, 200)
+        worksheet.set_column('E1:E10', 50)
+        count = 0
+        for id, df in self.df_excel.iterrows():
+            if df['code'] in self.tree_list:
+                count += 1
+                try:
+                    im = Image.open(str(df['img_path']))
+                    height = im.height + 15
+                    worksheet.set_row_pixels(count, height)
+                    worksheet.write(count, 0, str(df['class']))
+                    worksheet.write(count, 1, str(df['code']))
+                    worksheet.write(count, 2, str(df['code_num']))
+                    worksheet.write(count, 3, str(df['note']))
+                    worksheet.insert_image(count, 4, str(df['img_path']))
+                except:
+                    # image missing or unreadable: fall back to writing the path as text
+                    worksheet.write(count, 0, str(df['class']))
+                    worksheet.write(count, 1, str(df['code']))
+                    worksheet.write(count, 2, str(df['code_num']))
+                    worksheet.write(count, 3, str(df['note']))
+                    worksheet.write(count, 4, str(df['img_path']))
+
+        workbook.close()
+
+
+if __name__ == "__main__":
+
+    app = QApplication(sys.argv)
+    app.aboutToQuit.connect(app.deleteLater)
+    tp = TreeWidget()
+    tp.show()
+    app.exec_()

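The export above mixes text cells with embedded images; stripped of the tree-widget plumbing, the xlsxwriter calls reduce to this minimal sketch (hypothetical file names):

    import xlsxwriter

    wb = xlsxwriter.Workbook('demo.xlsx')
    ws = wb.add_worksheet('sheet1')
    ws.write(0, 0, 'class')              # header cell
    ws.set_row_pixels(1, 120)            # make the row tall enough for the image
    ws.insert_image(1, 0, 'sample.jpg')  # anchor the image at row 1, column 0
    wb.close()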
+ 8 - 0
code/data_manage/test_util/Qt5/excel/__init__.py

@@ -0,0 +1,8 @@
+"""
+# File       : __init__.py
+# Time       :21.6.24 15:27
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""

+ 5 - 0
code/data_manage/test_util/Qt5/excel/data/excel.conf

@@ -0,0 +1,5 @@
+[excel]
+# path of the excel file that stores the annotation metadata
+read_excel_path = "data/image.xlsx"
+# save path for the one-click excel export
+write_excel_path = "data/宁德时代.xlsx"

BIN
code/data_manage/test_util/Qt5/excel/data/image.xlsx


+ 252 - 0
code/data_manage/test_util/Qt5/frame.py

@@ -0,0 +1,252 @@
+# -*- coding: utf-8 -*-
+
+# Form implementation generated from reading ui file 'frame.ui'
+#
+# Created by: PyQt5 UI code generator 5.15.4
+#
+# WARNING: Any manual changes made to this file will be lost when pyuic5 is
+# run again.  Do not edit this file unless you know what you are doing.
+
+import os
+import sys
+import shutil
+import pandas as pd
+import base64
+import mysql.connector
+from PyQt5.QtWidgets import QFileDialog, QMessageBox
+from PyQt5 import QtCore, QtGui, QtWidgets
+
+class Ui_Form(QtWidgets.QMainWindow):
+    def __init__(self):
+        super(Ui_Form, self).__init__()
+        self.setupUi(self)
+        self.retranslateUi(self)
+        self.cwd = os.getcwd()
+
+    def setupUi(self, Form):
+        Form.setObjectName("Form")
+        Form.resize(546, 341)
+        self.layoutWidget = QtWidgets.QWidget(Form)
+        self.layoutWidget.setGeometry(QtCore.QRect(5, 180, 518, 111))
+        self.layoutWidget.setObjectName("layoutWidget")
+        self.verticalLayout_2 = QtWidgets.QVBoxLayout(self.layoutWidget)
+        self.verticalLayout_2.setContentsMargins(0, 0, 0, 0)
+        self.verticalLayout_2.setObjectName("verticalLayout_2")
+        self.horizontalLayout_4 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_4.setObjectName("horizontalLayout_4")
+        self.label_9 = QtWidgets.QLabel(self.layoutWidget)
+        self.label_9.setObjectName("label_9")
+        self.horizontalLayout_4.addWidget(self.label_9)
+        self.lineEdit_9 = QtWidgets.QLineEdit(self.layoutWidget)
+        self.lineEdit_9.setObjectName("lineEdit_9")
+        self.horizontalLayout_4.addWidget(self.lineEdit_9)
+        self.verticalLayout_2.addLayout(self.horizontalLayout_4)
+        self.pushButton_1 = QtWidgets.QPushButton(self.layoutWidget)
+        self.pushButton_1.setObjectName("pushButton_1")
+        self.verticalLayout_2.addWidget(self.pushButton_1)
+        self.label_17 = QtWidgets.QLabel(self.layoutWidget)
+        self.label_17.setObjectName("label_17")
+        self.verticalLayout_2.addWidget(self.label_17)
+        self.splitter = QtWidgets.QSplitter(Form)
+        self.splitter.setGeometry(QtCore.QRect(5, 0, 516, 181))
+        self.splitter.setOrientation(QtCore.Qt.Vertical)
+        self.splitter.setObjectName("splitter")
+        self.label_15 = QtWidgets.QLabel(self.splitter)
+        self.label_15.setObjectName("label_15")
+        self.widget = QtWidgets.QWidget(self.splitter)
+        self.widget.setObjectName("widget")
+        self.horizontalLayout_3 = QtWidgets.QHBoxLayout(self.widget)
+        self.horizontalLayout_3.setContentsMargins(0, 0, 0, 0)
+        self.horizontalLayout_3.setObjectName("horizontalLayout_3")
+        self.label_3 = QtWidgets.QLabel(self.widget)
+        self.label_3.setObjectName("label_3")
+        self.horizontalLayout_3.addWidget(self.label_3)
+        self.lineEdit_3 = QtWidgets.QLineEdit('127.0.0.1', self.widget)
+        self.lineEdit_3.setObjectName("lineEdit_3")
+        self.horizontalLayout_3.addWidget(self.lineEdit_3)
+        self.label_4 = QtWidgets.QLabel(self.widget)
+        self.label_4.setObjectName("label_4")
+        self.horizontalLayout_3.addWidget(self.label_4)
+        self.lineEdit_4 = QtWidgets.QLineEdit('3306', self.widget)  # MySQL default port
+        self.lineEdit_4.setObjectName("lineEdit_4")
+        self.horizontalLayout_3.addWidget(self.lineEdit_4)
+        self.widget1 = QtWidgets.QWidget(self.splitter)
+        self.widget1.setObjectName("widget1")
+        self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.widget1)
+        self.horizontalLayout_2.setContentsMargins(0, 0, 0, 0)
+        self.horizontalLayout_2.setObjectName("horizontalLayout_2")
+        self.label_5 = QtWidgets.QLabel(self.widget1)
+        self.label_5.setObjectName("label_5")
+        self.horizontalLayout_2.addWidget(self.label_5)
+        self.lineEdit_5 = QtWidgets.QLineEdit('root', self.widget1)
+        self.lineEdit_5.setObjectName("lineEdit_5")
+        self.horizontalLayout_2.addWidget(self.lineEdit_5)
+        self.label_6 = QtWidgets.QLabel(self.widget1)
+        self.label_6.setObjectName("label_6")
+        self.horizontalLayout_2.addWidget(self.label_6)
+        self.lineEdit_6 = QtWidgets.QLineEdit('root', self.widget1)
+        self.lineEdit_6.setObjectName("lineEdit_6")
+        self.horizontalLayout_2.addWidget(self.lineEdit_6)
+        self.widget2 = QtWidgets.QWidget(self.splitter)
+        self.widget2.setObjectName("widget2")
+        self.horizontalLayout_10 = QtWidgets.QHBoxLayout(self.widget2)
+        self.horizontalLayout_10.setContentsMargins(0, 0, 0, 0)
+        self.horizontalLayout_10.setObjectName("horizontalLayout_10")
+        self.horizontalLayout = QtWidgets.QHBoxLayout()
+        self.horizontalLayout.setObjectName("horizontalLayout")
+        self.label_7 = QtWidgets.QLabel(self.widget2)
+        self.label_7.setObjectName("label_7")
+        self.horizontalLayout.addWidget(self.label_7)
+        self.lineEdit_7 = QtWidgets.QLineEdit('db_img', self.widget2)
+        self.lineEdit_7.setObjectName("lineEdit_7")
+        self.horizontalLayout.addWidget(self.lineEdit_7)
+        self.horizontalLayout_10.addLayout(self.horizontalLayout)
+        self.horizontalLayout_9 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_9.setObjectName("horizontalLayout_9")
+        self.label_8 = QtWidgets.QLabel(self.widget2)
+        self.label_8.setObjectName("label_8")
+        self.horizontalLayout_9.addWidget(self.label_8)
+        self.lineEdit_8 = QtWidgets.QLineEdit(self.widget2)
+        self.lineEdit_8.setObjectName("lineEdit_8")
+        self.horizontalLayout_9.addWidget(self.lineEdit_8)
+        self.horizontalLayout_10.addLayout(self.horizontalLayout_9)
+        self.pushButton = QtWidgets.QPushButton(self.splitter)
+        self.pushButton.setObjectName("pushButton")
+        self.label_16 = QtWidgets.QLabel(self.splitter)
+        self.label_16.setObjectName("label_16")
+
+        self.pushButton.clicked.connect(self.push_button)
+        self.pushButton_1.clicked.connect(self.push_button_1)
+
+
+        self.retranslateUi(Form)
+        QtCore.QMetaObject.connectSlotsByName(Form)
+
+    def retranslateUi(self, Form):
+        _translate = QtCore.QCoreApplication.translate
+        Form.setWindowTitle(_translate("Form", "Form"))
+        self.label_9.setText(_translate("Form", "筛选条件(target_id):"))
+        self.pushButton_1.setText(_translate("Form", "生成数据"))
+        self.label_17.setText(_translate("Form", "--------------------------------------------------------------------------------------"))
+        self.label_15.setText(_translate("Form", "数据库配置信息"))
+        self.label_3.setText(_translate("Form", "Host:  "))
+        self.label_4.setText(_translate("Form", "Port:     "))
+        self.label_5.setText(_translate("Form", "User:  "))
+        self.label_6.setText(_translate("Form", "Password: "))
+        self.label_7.setText(_translate("Form", "dbName:"))
+        self.label_8.setText(_translate("Form", "tableName:"))
+        self.pushButton.setText(_translate("Form", "数据库连接测试"))
+        self.label_16.setText(_translate("Form", "--------------------------------------------------------------------------------------"))
+
+    def connect_database_success(self):
+        """
+        Message box
+        :return:
+        """
+        QMessageBox.about(self, "数据库提示", "数据库连接成功!")
+
+    def select_database_success(self):
+        """
+        Message box
+        :return:
+        """
+        QMessageBox.about(self, "数据库提示", "数据查询完成!")
+
+    def connect_database_error(self):
+        """
+        Message box
+        :return:
+        """
+        QMessageBox.about(self, "数据库错误提示", "数据库连接错误!请检查")
+
+    def show_filter_message_list_len_error(self):
+        """
+        Message box
+        :return:
+        """
+        QMessageBox.about(self, "错误提示", "筛选条件文本框错误,请检查!")
+
+    @staticmethod
+    def lineEdit_str2list(lineEdit):
+        if lineEdit.strip() == '':
+            return None
+        # split the comma-separated target_id string into a list
+        return lineEdit.split(',')
+
+    @staticmethod
+    def my_db_connect(host, user, passwd, database):
+        """
+        Create a MySQL connection
+        :return:
+        """
+        my_db = mysql.connector.connect(host=host, user=user, passwd=passwd, database=database, buffered=True)
+        return my_db
+
+    #host, user, passwd, database
+    def push_button(self):
+        try:
+            self.my_db_connect(self.lineEdit_3.text(), self.lineEdit_5.text(), self.lineEdit_6.text(), self.lineEdit_7.text())
+            self.connect_database_success()
+        except:
+            self.connect_database_error()
+
+    @staticmethod
+    def select_target_id(mydb, lineEdit_8, target_id_list):
+        str_id = "select target_id, zoomin_pic, param_1 from {} where".format(lineEdit_8)
+        for id in target_id_list:
+            str_id = str_id + " target_id="+"'" + id + "'" + " or "
+        sql = str_id.strip()[:-3]
+
+        mycursor = mydb.cursor()
+        mycursor.execute(sql)
+        result = mycursor.fetchall()
+        mycursor.close()
+        return result
+
+    def write_jpg_csv(self, results):
+        path = self.cwd+'/results'
+        if os.path.exists(path):
+            shutil.rmtree(path, ignore_errors=True)
+        os.mkdir(path)
+        text_file = open(path+'/results.txt', 'a')
+        for result in results:
+            text = str(result[0])+','+str(result[2])+'\n'
+            text_file.writelines(text)
+            img = str(result[1])
+            imgdata = base64.b64decode(img)
+            file_name = path+'/'+str(result[0])+'.jpg'
+            file = open(file_name, 'wb')
+            file.write(imgdata)
+            file.close()
+        text_file.close()
+
+
+    def push_button_1(self):
+        if self.lineEdit_9.text() and self.lineEdit_8.text():
+            target_id_list = self.lineEdit_str2list(self.lineEdit_9.text())
+            my_db = self.my_db_connect(self.lineEdit_3.text(), self.lineEdit_5.text(), self.lineEdit_6.text(), self.lineEdit_7.text())
+            try:
+                results = self.select_target_id(my_db, self.lineEdit_8.text(), target_id_list)
+                self.write_jpg_csv(results)
+
+            except:
+                self.connect_database_error()
+            my_db.close()
+            self.select_database_success()
+        else:
+            self.show_filter_message_list_len_error()
+
+
+
+if __name__ == "__main__":
+    app = QtWidgets.QApplication(sys.argv)
+    QMainWindow = QtWidgets.QMainWindow()
+    ui = Ui_Form()
+    ui.setupUi(QMainWindow)
+    QMainWindow.show()
+    sys.exit(app.exec_())
+
+
+

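select_target_id assembles its WHERE clause by string concatenation. mysql.connector also accepts %s placeholders, which sidesteps quoting problems; a sketch of the same query in that style (the table name cannot be a placeholder, so it is still interpolated):

    def select_target_id_param(mydb, table, target_id_list):
        placeholders = ', '.join(['%s'] * len(target_id_list))
        sql = "SELECT target_id, zoomin_pic, param_1 FROM {} WHERE target_id IN ({})".format(table, placeholders)
        cur = mydb.cursor()
        cur.execute(sql, tuple(target_id_list))  # the driver fills in the %s values
        rows = cur.fetchall()
        cur.close()
        return rows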
+ 167 - 0
code/data_manage/test_util/Qt5/frame.ui

@@ -0,0 +1,167 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<ui version="4.0">
+ <class>Form</class>
+ <widget class="QWidget" name="Form">
+  <property name="geometry">
+   <rect>
+    <x>0</x>
+    <y>0</y>
+    <width>546</width>
+    <height>341</height>
+   </rect>
+  </property>
+  <property name="windowTitle">
+   <string>Form</string>
+  </property>
+  <widget class="QWidget" name="layoutWidget">
+   <property name="geometry">
+    <rect>
+     <x>0</x>
+     <y>180</y>
+     <width>518</width>
+     <height>111</height>
+    </rect>
+   </property>
+   <layout class="QVBoxLayout" name="verticalLayout_2">
+    <item>
+     <layout class="QHBoxLayout" name="horizontalLayout_4">
+      <item>
+       <widget class="QLabel" name="label_9">
+        <property name="text">
+         <string>筛选条件(target_id):</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLineEdit" name="lineEdit_9"/>
+      </item>
+     </layout>
+    </item>
+    <item>
+     <widget class="QPushButton" name="pushButton_1">
+      <property name="text">
+       <string>生成数据</string>
+      </property>
+     </widget>
+    </item>
+    <item>
+     <widget class="QLabel" name="label_17">
+      <property name="text">
+       <string>--------------------------------------------------------------------------------------</string>
+      </property>
+     </widget>
+    </item>
+   </layout>
+  </widget>
+  <widget class="QSplitter" name="splitter">
+   <property name="geometry">
+    <rect>
+     <x>0</x>
+     <y>0</y>
+     <width>516</width>
+     <height>181</height>
+    </rect>
+   </property>
+   <property name="orientation">
+    <enum>Qt::Vertical</enum>
+   </property>
+   <widget class="QLabel" name="label_15">
+    <property name="text">
+     <string>数据库配置信息</string>
+    </property>
+   </widget>
+   <widget class="QWidget" name="">
+    <layout class="QHBoxLayout" name="horizontalLayout_3">
+     <item>
+      <widget class="QLabel" name="label_3">
+       <property name="text">
+        <string>Host:  </string>
+       </property>
+      </widget>
+     </item>
+     <item>
+      <widget class="QLineEdit" name="lineEdit_3"/>
+     </item>
+     <item>
+      <widget class="QLabel" name="label_4">
+       <property name="text">
+        <string>Port:     </string>
+       </property>
+      </widget>
+     </item>
+     <item>
+      <widget class="QLineEdit" name="lineEdit_4"/>
+     </item>
+    </layout>
+   </widget>
+   <widget class="QWidget" name="">
+    <layout class="QHBoxLayout" name="horizontalLayout_2">
+     <item>
+      <widget class="QLabel" name="label_5">
+       <property name="text">
+        <string>User:  </string>
+       </property>
+      </widget>
+     </item>
+     <item>
+      <widget class="QLineEdit" name="lineEdit_5"/>
+     </item>
+     <item>
+      <widget class="QLabel" name="label_6">
+       <property name="text">
+        <string>Password: </string>
+       </property>
+      </widget>
+     </item>
+     <item>
+      <widget class="QLineEdit" name="lineEdit_6"/>
+     </item>
+    </layout>
+   </widget>
+   <widget class="QWidget" name="">
+    <layout class="QHBoxLayout" name="horizontalLayout_10">
+     <item>
+      <layout class="QHBoxLayout" name="horizontalLayout">
+       <item>
+        <widget class="QLabel" name="label_7">
+         <property name="text">
+          <string>dbName:</string>
+         </property>
+        </widget>
+       </item>
+       <item>
+        <widget class="QLineEdit" name="lineEdit_7"/>
+       </item>
+      </layout>
+     </item>
+     <item>
+      <layout class="QHBoxLayout" name="horizontalLayout_9">
+       <item>
+        <widget class="QLabel" name="label_8">
+         <property name="text">
+          <string>tableName:</string>
+         </property>
+        </widget>
+       </item>
+       <item>
+        <widget class="QLineEdit" name="lineEdit_8"/>
+       </item>
+      </layout>
+     </item>
+    </layout>
+   </widget>
+   <widget class="QPushButton" name="pushButton">
+    <property name="text">
+     <string>数据库连接测试</string>
+    </property>
+   </widget>
+   <widget class="QLabel" name="label_16">
+    <property name="text">
+     <string>--------------------------------------------------------------------------------------</string>
+    </property>
+   </widget>
+  </widget>
+ </widget>
+ <resources/>
+ <connections/>
+</ui>

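frame.py above is the pyuic5 output for this .ui file. PyQt5 can also build the widget tree straight from the XML at runtime, which avoids regenerating code after Designer edits; a sketch, assuming frame.ui sits next to the script:

    import sys
    from PyQt5 import QtWidgets, uic

    app = QtWidgets.QApplication(sys.argv)
    window = uic.loadUi('frame.ui')  # constructs the widgets from the XML
    window.show()
    sys.exit(app.exec_())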
+ 524 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/frame.py

@@ -0,0 +1,524 @@
+# -*- coding: utf-8 -*-
+
+# Form implementation generated from reading ui file 'frame.ui'
+#
+# Created by: PyQt5 UI code generator 5.15.4
+#
+# WARNING: Any manual changes made to this file will be lost when pyuic5 is
+# run again.  Do not edit this file unless you know what you are doing.
+import os
+import sys
+import pandas as pd
+import mysql.connector
+from PyQt5 import QtCore, QtWidgets
+from gen_data.utils import select_data_util, write_xml_jpg_2_all_data_util, pathUtil
+from PyQt5.QtWidgets import QFileDialog, QMessageBox
+from gen_data import check_data_step2, classFeatureImages_step3, gen_class_index_step4, genAnn_step5, \
+    splitTrainVal_step6, gen_pb_tfrecord_step7
+
+
+class Ui_Form(QtWidgets.QMainWindow):
+    def __init__(self):
+        super(Ui_Form, self).__init__()
+        self.setupUi(self)
+        self.retranslateUi(self)
+        self.cwd = os.getcwd()
+
+    def setupUi(self, Form):
+        Form.setObjectName("Form")
+        Form.resize(648, 783)
+        self.widget = QtWidgets.QWidget(Form)
+        self.widget.setGeometry(QtCore.QRect(60, 35, 520, 616))
+        self.widget.setObjectName("widget")
+        self.verticalLayout_5 = QtWidgets.QVBoxLayout(self.widget)
+        self.verticalLayout_5.setContentsMargins(0, 0, 0, 0)
+        self.verticalLayout_5.setObjectName("verticalLayout_5")
+        self.verticalLayout = QtWidgets.QVBoxLayout()
+        self.verticalLayout.setObjectName("verticalLayout")
+        self.label_15 = QtWidgets.QLabel(self.widget)
+        self.label_15.setObjectName("label_15")
+        self.verticalLayout.addWidget(self.label_15, 0, QtCore.Qt.AlignHCenter)
+        self.horizontalLayout_3 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_3.setObjectName("horizontalLayout_3")
+        self.label_3 = QtWidgets.QLabel(self.widget)
+        self.label_3.setObjectName("label_3")
+        self.horizontalLayout_3.addWidget(self.label_3)
+        self.lineEdit_3 = QtWidgets.QLineEdit('127.0.0.1', self.widget)
+        self.lineEdit_3.setObjectName("lineEdit_3")
+        self.horizontalLayout_3.addWidget(self.lineEdit_3)
+        self.label_4 = QtWidgets.QLabel(self.widget)
+        self.label_4.setObjectName("label_4")
+        self.horizontalLayout_3.addWidget(self.label_4)
+        self.lineEdit_4 = QtWidgets.QLineEdit('3306', self.widget)
+        self.lineEdit_4.setObjectName("lineEdit_4")
+        self.horizontalLayout_3.addWidget(self.lineEdit_4)
+        self.verticalLayout.addLayout(self.horizontalLayout_3)
+        self.horizontalLayout_2 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_2.setObjectName("horizontalLayout_2")
+        self.label_5 = QtWidgets.QLabel(self.widget)
+        self.label_5.setObjectName("label_5")
+        self.horizontalLayout_2.addWidget(self.label_5)
+        self.lineEdit_5 = QtWidgets.QLineEdit('root', self.widget)
+        self.lineEdit_5.setObjectName("lineEdit_5")
+        self.horizontalLayout_2.addWidget(self.lineEdit_5)
+        self.label_6 = QtWidgets.QLabel(self.widget)
+        self.label_6.setObjectName("label_6")
+        self.horizontalLayout_2.addWidget(self.label_6)
+        self.lineEdit_6 = QtWidgets.QLineEdit('root', self.widget)
+        self.lineEdit_6.setEchoMode(QtWidgets.QLineEdit.Password)
+        self.lineEdit_6.setObjectName("lineEdit_6")
+        self.horizontalLayout_2.addWidget(self.lineEdit_6)
+        self.verticalLayout.addLayout(self.horizontalLayout_2)
+        self.horizontalLayout = QtWidgets.QHBoxLayout()
+        self.horizontalLayout.setObjectName("horizontalLayout")
+        self.label_7 = QtWidgets.QLabel(self.widget)
+        self.label_7.setObjectName("label_7")
+        self.horizontalLayout.addWidget(self.label_7)
+        self.lineEdit_7 = QtWidgets.QLineEdit('db_img', self.widget)
+        self.lineEdit_7.setObjectName("lineEdit_7")
+        self.horizontalLayout.addWidget(self.lineEdit_7)
+        self.pushButton = QtWidgets.QPushButton(self.widget)
+        self.pushButton.setObjectName("pushButton")
+        self.horizontalLayout.addWidget(self.pushButton)
+        self.verticalLayout.addLayout(self.horizontalLayout)
+        self.label_8 = QtWidgets.QLabel(self.widget)
+        self.label_8.setText("")
+        self.label_8.setObjectName("label_8")
+        self.verticalLayout.addWidget(self.label_8)
+        self.label_16 = QtWidgets.QLabel(self.widget)
+        self.label_16.setObjectName("label_16")
+        self.verticalLayout.addWidget(self.label_16)
+        self.verticalLayout_5.addLayout(self.verticalLayout)
+        self.verticalLayout_2 = QtWidgets.QVBoxLayout()
+        self.verticalLayout_2.setObjectName("verticalLayout_2")
+        self.label = QtWidgets.QLabel(self.widget)
+        self.label.setObjectName("label")
+        self.verticalLayout_2.addWidget(self.label, 0, QtCore.Qt.AlignHCenter)
+        self.label_10 = QtWidgets.QLabel(self.widget)
+        self.label_10.setObjectName("label_10")
+        self.verticalLayout_2.addWidget(self.label_10)
+        self.label_12 = QtWidgets.QLabel(self.widget)
+        self.label_12.setObjectName("label_12")
+        self.verticalLayout_2.addWidget(self.label_12)
+        self.label_13 = QtWidgets.QLabel(self.widget)
+        self.label_13.setObjectName("label_13")
+        self.verticalLayout_2.addWidget(self.label_13)
+        self.label_14 = QtWidgets.QLabel(self.widget)
+        self.label_14.setObjectName("label_14")
+        self.verticalLayout_2.addWidget(self.label_14)
+        self.label_11 = QtWidgets.QLabel(self.widget)
+        self.label_11.setObjectName("label_11")
+        self.verticalLayout_2.addWidget(self.label_11)
+        self.horizontalLayout_4 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_4.setObjectName("horizontalLayout_4")
+        self.label_9 = QtWidgets.QLabel(self.widget)
+        self.label_9.setObjectName("label_9")
+        self.horizontalLayout_4.addWidget(self.label_9)
+        self.lineEdit_9 = QtWidgets.QLineEdit(self.widget)
+        self.lineEdit_9.setObjectName("lineEdit_9")
+        self.horizontalLayout_4.addWidget(self.lineEdit_9)
+        self.verticalLayout_2.addLayout(self.horizontalLayout_4)
+        self.label_27 = QtWidgets.QLabel(self.widget)
+        self.label_27.setObjectName("label_27")
+        self.verticalLayout_2.addWidget(self.label_27)
+        self.label_28 = QtWidgets.QLabel(self.widget)
+        self.label_28.setObjectName("label_28")
+        self.verticalLayout_2.addWidget(self.label_28)
+        self.horizontalLayout_7 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_7.setObjectName("horizontalLayout_7")
+        self.label_21 = QtWidgets.QLabel(self.widget)
+        self.label_21.setObjectName("label_21")
+        self.horizontalLayout_7.addWidget(self.label_21)
+        self.pushButton_21 = QtWidgets.QPushButton(self.widget)
+        self.pushButton_21.setObjectName("pushButton_21")
+        self.horizontalLayout_7.addWidget(self.pushButton_21)
+        self.lineEdit_21 = QtWidgets.QLineEdit(self.widget)
+        self.lineEdit_21.setObjectName("lineEdit_21")
+        self.horizontalLayout_7.addWidget(self.lineEdit_21)
+        self.verticalLayout_2.addLayout(self.horizontalLayout_7)
+        self.pushButton_1 = QtWidgets.QPushButton(self.widget)
+        self.pushButton_1.setObjectName("pushButton_1")
+        self.verticalLayout_2.addWidget(self.pushButton_1)
+        self.label_26 = QtWidgets.QLabel(self.widget)
+        self.label_26.setText("")
+        self.label_26.setObjectName("label_26")
+        self.verticalLayout_2.addWidget(self.label_26)
+        self.label_17 = QtWidgets.QLabel(self.widget)
+        self.label_17.setObjectName("label_17")
+        self.verticalLayout_2.addWidget(self.label_17)
+        self.verticalLayout_5.addLayout(self.verticalLayout_2)
+        self.verticalLayout_3 = QtWidgets.QVBoxLayout()
+        self.verticalLayout_3.setObjectName("verticalLayout_3")
+        self.label_23 = QtWidgets.QLabel(self.widget)
+        self.label_23.setObjectName("label_23")
+        self.verticalLayout_3.addWidget(self.label_23, 0, QtCore.Qt.AlignHCenter)
+        self.horizontalLayout_5 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_5.setObjectName("horizontalLayout_5")
+        self.label_18 = QtWidgets.QLabel(self.widget)
+        self.label_18.setObjectName("label_18")
+        self.horizontalLayout_5.addWidget(self.label_18)
+        self.lineEdit_18 = QtWidgets.QLineEdit(self.widget)
+        self.lineEdit_18.setObjectName("lineEdit_18")
+        self.horizontalLayout_5.addWidget(self.lineEdit_18)
+        self.verticalLayout_3.addLayout(self.horizontalLayout_5)
+        self.horizontalLayout_8 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_8.setObjectName("horizontalLayout_8")
+        self.verticalLayout_3.addLayout(self.horizontalLayout_8)
+        self.pushButton_2 = QtWidgets.QPushButton(self.widget)
+        self.pushButton_2.setObjectName("pushButton_2")
+        self.verticalLayout_3.addWidget(self.pushButton_2)
+        self.label_2 = QtWidgets.QLabel(self.widget)
+        self.label_2.setText("")
+        self.label_2.setObjectName("label_2")
+        self.verticalLayout_3.addWidget(self.label_2)
+        self.label_22 = QtWidgets.QLabel(self.widget)
+        self.label_22.setObjectName("label_22")
+        self.verticalLayout_3.addWidget(self.label_22)
+        self.verticalLayout_5.addLayout(self.verticalLayout_3)
+        self.verticalLayout_4 = QtWidgets.QVBoxLayout()
+        self.verticalLayout_4.setObjectName("verticalLayout_4")
+        self.label_20 = QtWidgets.QLabel(self.widget)
+        self.label_20.setEnabled(True)
+        self.label_20.setObjectName("label_20")
+        self.verticalLayout_4.addWidget(self.label_20, 0, QtCore.Qt.AlignHCenter)
+        self.horizontalLayout_6 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_6.setObjectName("horizontalLayout_6")
+        self.label_25 = QtWidgets.QLabel(self.widget)
+        self.label_25.setObjectName("label_25")
+        self.horizontalLayout_6.addWidget(self.label_25)
+        self.pushButton_25 = QtWidgets.QPushButton(self.widget)
+        self.pushButton_25.setObjectName("pushButton_25")
+        self.horizontalLayout_6.addWidget(self.pushButton_25)
+        self.lineEdit_25 = QtWidgets.QLineEdit(self.widget)
+        self.lineEdit_25.setObjectName("lineEdit_25")
+        self.horizontalLayout_6.addWidget(self.lineEdit_25)
+        self.verticalLayout_4.addLayout(self.horizontalLayout_6)
+        self.pushButton_3 = QtWidgets.QPushButton(self.widget)
+        self.pushButton_3.setObjectName("pushButton_3")
+        self.verticalLayout_4.addWidget(self.pushButton_3)
+        self.verticalLayout_5.addLayout(self.verticalLayout_4)
+
+        self.pushButton_21.clicked.connect(self.openfolder_images_path_lineEdit_21)
+        # self.pushButton_24.clicked.connect(self.openfolder_images_path_lineEdit_24)
+        self.pushButton_25.clicked.connect(self.openfolder_images_path_lineEdit_25)
+        self.pushButton_1.clicked.connect(self.push_button_1)
+        self.pushButton_2.clicked.connect(self.push_button_2)
+        self.pushButton_3.clicked.connect(self.push_button_3)
+
+        self.retranslateUi(Form)
+        QtCore.QMetaObject.connectSlotsByName(Form)
+
+    def retranslateUi(self, Form):
+        _translate = QtCore.QCoreApplication.translate
+        Form.setWindowTitle(_translate("Form", "Form"))
+        self.label_15.setText(_translate("Form", "数据库配置信息"))
+        self.label_3.setText(_translate("Form", "Host:  "))
+        self.label_4.setText(_translate("Form", "Port:     "))
+        self.label_5.setText(_translate("Form", "User:  "))
+        self.label_6.setText(_translate("Form", "Password: "))
+        self.label_7.setText(_translate("Form", "dbName:"))
+        self.pushButton.setText(_translate("Form", "数据库连接测试"))
+        self.label_16.setText(_translate("Form",
+                                         "--------------------------------------------------------------------------------------"))
+        self.label.setText(_translate("Form", "工具1:根据类别和项目名称提取特征数据生成tfrecord"))
+        self.label_10.setText(_translate("Form", "提示:筛选条件写入格式为 class_name,spot_id,percent。多个项目数据之间用英文分号隔开"))
+        self.label_12.setText(_translate("Form", "    class_name 类别名称(不是state_name)要获取全部,填写all"))
+        self.label_13.setText(_translate("Form", "    spot_id 地点ID,要获取全部,填写0(此时的class_name不能为all)"))
+        self.label_14.setText(_translate("Form", "    percent 获取数据的百分比(0-1],要获取全部,填写1"))
+        self.label_11.setText(_translate("Form", "    示例:E1,0,0.2;K10,12,0.3;all,1,1"))
+        self.label_9.setText(_translate("Form", "筛选条件:"))
+        self.label_27.setText(_translate("Form", "注意:数据库中图片路径位置必须和数据保存路径在同一个磁盘映射下"))
+        self.label_28.setText(_translate("Form", "例如:录入数据库的图片存放在C盘,那么数据存放路径也必须在C盘,否则将无法获取图片"))
+        self.label_21.setText(_translate("Form", "数据保存路径:"))
+        self.pushButton_21.setText(_translate("Form", "打开文件夹"))
+        self.pushButton_1.setText(_translate("Form", "工具1:生成数据"))
+        self.label_17.setText(_translate("Form",
+                                         "--------------------------------------------------------------------------------------"))
+        self.label_23.setText(_translate("Form", "工具2:查询类别数量"))
+        self.label_18.setText(_translate("Form", "类别名称:"))
+        self.pushButton_2.setText(_translate("Form", "工具2:查询数据"))
+        self.label_22.setText(_translate("Form",
+                                         "--------------------------------------------------------------------------------------"))
+        self.label_20.setText(_translate("Form", "工具3:根据项目文件夹名称获取spot_id"))
+        self.label_25.setText(_translate("Form", "项目文件夹名称:"))
+        self.pushButton_25.setText(_translate("Form", "打开文件夹"))
+        self.pushButton_3.setText(_translate("Form", "工具3:提交"))
+
+    def show_message_folder(self):
+        """
+        Message box
+        :return:
+        """
+        QMessageBox.about(self, "空值提示", "没有获取到相应的文件夹路径!")
+
+    def show_message_isempty(self):
+        """
+        Message box
+        :return:
+        """
+        QMessageBox.about(self, "文本框有空值错误提示", "文本框不能为空,请检查!")
+
+    def show_filter_message_list_len_error(self):
+        """
+        Message box
+        :return:
+        """
+        QMessageBox.about(self, "错误提示", "筛选条件文本框错误,请检查!")
+
+    def show_filter_message_common_error(self):
+        """
+        Message box
+        :return:
+        """
+        QMessageBox.about(self, "错误提示", "非通用字段文本框错误,请检查!")
+
+    def show_message_num(self):
+        """
+        Message box
+        :return:
+        """
+        QMessageBox.about(self, "数字错误提示", "输入的不是数字,或者输入的数字小于0!")
+
+    def openfolder_images_path_lineEdit_21(self):
+        """
+        Choose the folder where the data will be saved
+        :return:
+        """
+        openfolder_path = QFileDialog.getExistingDirectory(self, '选择数据保存路径', self.cwd)
+        self.lineEdit_21.setText(openfolder_path)
+
+    def openfolder_images_path_lineEdit_24(self):
+        """
+        Choose the folder where the data will be saved
+        :return:
+        """
+        openfolder_path = QFileDialog.getExistingDirectory(self, '选择数据保存路径', self.cwd)
+        self.lineEdit_24.setText(openfolder_path)
+
+    def openfolder_images_path_lineEdit_25(self):
+        """
+        Choose the folder where the data will be saved
+        :return:
+        """
+        openfolder_path = QFileDialog.getExistingDirectory(self, '打开文件夹', self.cwd)
+        self.lineEdit_25.setText(openfolder_path)
+
+    def show_spot_id(self, spot_name, spot_id):
+        """
+        Message box
+        :return:
+        """
+        QMessageBox.about(self, "查询结果提示", "%s对应的ID为:%s" % (spot_name, str(spot_id)))
+
+    def show_class_name_count_num(self, result):
+        """
+        Message box
+        :return:
+        """
+        QMessageBox.about(self, "查询结果提示", "类别对应的数据量为:%s" % (result))
+
+    def show_database_error(self):
+        """
+        Message box
+        :return:
+        """
+        QMessageBox.about(self, "数据库错误提示", "数据库查询错误!请检查查询条件")
+
+    def show_message_succes(self):
+        """
+        Message box
+        :return:
+        """
+        QMessageBox.about(self, "完成提示", "数据生成完毕!")
+
+    @staticmethod
+    def lineEdit_9_str2list(lineEdit):
+        lists = []
+        if lineEdit.strip() == '':
+            return None
+        for lines in lineEdit.split(';'):
+            line_list = lines.split(',')
+            spot_id = int(line_list[1])
+            percent = float(line_list[2])
+            if line_list[0].lower() == 'all' and spot_id == 0:
+                return None
+
+            if len(line_list) == 3 and spot_id >= 0 and 0 < percent <= 1:
+                lists.append([line_list[0].upper(), int(spot_id), float(percent)])
+            else:
+                return None
+
+        return lists
+
+    @staticmethod
+    def lineEdit_18_str2list(lineEdit):
+        lists = []
+        if lineEdit.strip() == '':
+            return None
+        for lines in lineEdit.split(','):
+            lists.extend(lines.split(','))
+        return lists
+
+    @staticmethod
+    def my_db_connect(host, user, passwd, database):
+        """
+        Create a MySQL connection
+        :return:
+        """
+        my_db = mysql.connector.connect(host=host, user=user, passwd=passwd, database=database, buffered=True)
+        return my_db
+
+    def push_button_1(self):
+        filter_list = None  # defined up front so the except branch below cannot leave it unset
+        if self.lineEdit_9.text() and self.lineEdit_21.text():
+            try:
+                filter_list = self.lineEdit_9_str2list(self.lineEdit_9.text())
+            except:
+                self.show_filter_message_list_len_error()
+            if filter_list:
+                all_df = pd.DataFrame()
+                my_db = self.my_db_connect(self.lineEdit_3.text(), self.lineEdit_5.text(), self.lineEdit_6.text(),
+                                           self.lineEdit_7.text())
+                all_df = self.list_items_select_data_2_df(my_db, filter_list, all_df)
+
+                write_xml_jpg_2_all_data_util.copy_jpg_thread(all_df, pathUtil.path_format(self.lineEdit_21.text()))
+                self.step2_to_step7(self.lineEdit_21.text())
+            # k16,1,0.3
+            else:
+                self.show_filter_message_list_len_error()
+
+        else:
+            self.show_message_isempty()
+
+    # def push_button_2(self):
+    #     if self.lineEdit_19.text() and self.lineEdit_18.text() and self.lineEdit_24.text():
+    #
+    #         try:
+    #             filter_list = self.lineEdit_18_str2list(self.lineEdit_18.text())
+    #             if filter_list:
+    #                 try:
+    #                     percent = float(self.lineEdit_19.text())
+    #                     if 0 < percent <= 1:
+    #
+    #                         print(filter_list)
+    #                     else:
+    #                         self.show_message_num()
+    #                 except:
+    #                     self.show_message_num()
+    #             else:
+    #                 self.show_filter_message_common_error()
+    #         except:
+    #             self.show_filter_message_common_error()
+    #     else:
+    #         self.show_message_isempty()
+
+    def push_button_2(self):
+        if self.lineEdit_18.text():
+            my_db = self.my_db_connect(self.lineEdit_3.text(), self.lineEdit_5.text(), self.lineEdit_6.text(),
+                                       self.lineEdit_7.text())
+            try:
+                result_str = self.select_class_name_count_num(my_db, self.lineEdit_18.text())
+                self.show_class_name_count_num(result_str)
+            except:
+                self.show_database_error()
+        else:
+            self.show_message_isempty()
+
+    def push_button_3(self):
+        if self.lineEdit_25.text():
+            spot_name = self.lineEdit_25.text().split('/')[-1].split('\\')[-1]
+            my_db = self.my_db_connect(self.lineEdit_3.text(), self.lineEdit_5.text(), self.lineEdit_6.text(),
+                                       self.lineEdit_7.text())
+            try:
+                spot_id = select_data_util.select_spot_name_get_spot_id(my_db, spot_name)
+                self.show_spot_id(spot_name, spot_id)
+            except:
+                self.show_database_error()
+            my_db.close()
+        else:
+            self.show_message_isempty()
+
+    def list_items_select_data_2_df(self, my_db, lists, all_df):
+        for li in lists:
+            if li[0].upper() == 'ALL':
+                df = self.select_spot_id_2_df(my_db, li[1], li[2])
+                all_df = all_df.append(df)
+            elif li[1] == 0:
+                df = self.select_class_name_2_df(my_db, li[0], li[2])
+                all_df = all_df.append(df)
+            else:
+                df = self.select_class_name_and_spot_id_2_df(my_db, li[0], li[1], li[2])
+                all_df = all_df.append(df)
+        return all_df
+
+    @staticmethod
+    def select_class_name_and_spot_id_2_df(mydb, class_name, spot_id, percent):
+
+        if '-' in class_name:
+            state_name_list = select_data_util.select_state_name_and_spot_id(mydb, class_name, spot_id)
+        else:
+            state_name_list = select_data_util.select_class_name_and_spot_id(mydb, class_name, spot_id)
+        total_df = pd.DataFrame(state_name_list,
+                                columns=['path', 'filename', 'width', 'height', 'depth', 'class_name', 'xmin', 'ymin',
+                                         'xmax', 'ymax'])
+        total_df = total_df.sample(frac=percent)
+        return total_df
+
+    @staticmethod
+    def select_class_name_2_df(mydb, class_name, percent):
+
+        if '-' in class_name:
+            state_name_list = select_data_util.select_state_name(mydb, class_name)
+        else:
+            state_name_list = select_data_util.select_class_name(mydb, class_name)
+        total_df = pd.DataFrame(state_name_list,
+                                columns=['path', 'filename', 'width', 'height', 'depth', 'class_name', 'xmin', 'ymin',
+                                         'xmax', 'ymax'])
+        total_df = total_df.sample(frac=percent)
+        return total_df
+
+    @staticmethod
+    def select_class_name_count_num(mydb, class_name):
+        results_str = ''
+        if '-' in class_name:
+            result = select_data_util.select_state_name_count_num(mydb, class_name.upper())
+            results_str = results_str + ',' + '%s:%d' % (class_name, result[0])
+        else:
+            state_name_list = select_data_util.select_class_state_class_name(mydb, class_name.upper())
+            print(state_name_list)
+            for state_name in state_name_list:
+                result = select_data_util.select_state_name_count_num(mydb, state_name[0].upper())
+                results_str = results_str + ',' + '%s:%d' % (state_name[0], result[0])
+        results_str = results_str[1:]
+        return results_str
+
+    @staticmethod
+    def select_spot_id_2_df(mydb, spot_id, percent):
+
+        state_name_list = select_data_util.select_spot_id(mydb, spot_id)
+        total_df = pd.DataFrame(state_name_list,
+                                columns=['path', 'filename', 'width', 'height', 'depth', 'class_name', 'xmin', 'ymin',
+                                         'xmax', 'ymax'])
+        total_df = total_df.sample(frac=percent)
+        return total_df
+
+    @staticmethod
+    def step2_to_step7(path):
+
+        check_data_step2.checkFile(path).main()
+        classFeatureImages_step3.classFeatureImages(path).main()
+        gen_class_index_step4.genClassIndex(path).main()
+        genAnn_step5.ganAnn(path).main()
+        splitTrainVal_step6.splitTrainVal(path).main()
+        gen_pb_tfrecord_step7.genPb(path).main()
+        gen_pb_tfrecord_step7.genTfrecord(path).main()
+
+
+if __name__ == "__main__":
+    app = QtWidgets.QApplication(sys.argv)
+    QMainWindow = QtWidgets.QMainWindow()
+    ui = Ui_Form()
+    ui.setupUi(QMainWindow)
+    QMainWindow.show()
+    sys.exit(app.exec_())

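step2_to_step7 chains the numbered pipeline stages against one data directory; the same sequence can be driven headlessly for testing without the GUI. A sketch with a hypothetical path:

    from gen_data import check_data_step2, classFeatureImages_step3, gen_class_index_step4, \
        genAnn_step5, splitTrainVal_step6, gen_pb_tfrecord_step7

    data_dir = 'D:/datasets/project_a'  # hypothetical
    check_data_step2.checkFile(data_dir).main()                   # step2: xml/jpg pairing check
    classFeatureImages_step3.classFeatureImages(data_dir).main()  # step3: crop features by class
    gen_class_index_step4.genClassIndex(data_dir).main()          # step4: build the class index
    genAnn_step5.ganAnn(data_dir).main()                          # step5: generate annotations
    splitTrainVal_step6.splitTrainVal(data_dir).main()            # step6: train/val split
    gen_pb_tfrecord_step7.genPb(data_dir).main()                  # step7: pb file
    gen_pb_tfrecord_step7.genTfrecord(data_dir).main()            # step7: tfrecord file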
+ 332 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/frame.ui

@@ -0,0 +1,332 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<ui version="4.0">
+ <class>Form</class>
+ <widget class="QWidget" name="Form">
+  <property name="geometry">
+   <rect>
+    <x>0</x>
+    <y>0</y>
+    <width>648</width>
+    <height>765</height>
+   </rect>
+  </property>
+  <property name="windowTitle">
+   <string>Form</string>
+  </property>
+  <widget class="QWidget" name="">
+   <layout class="QVBoxLayout" name="verticalLayout_2">
+    <item alignment="Qt::AlignHCenter">
+     <widget class="QLabel" name="label">
+      <property name="text">
+       <string>工具1:根据类别和项目名称提取特征数据生成tfrecord</string>
+      </property>
+     </widget>
+    </item>
+    <item>
+     <widget class="QLabel" name="label_10">
+      <property name="text">
+       <string>提示:筛选条件写入格式为 class_name,spot_id,percent。多个项目数据之间用英文分号隔开</string>
+      </property>
+     </widget>
+    </item>
+    <item>
+     <widget class="QLabel" name="label_12">
+      <property name="text">
+       <string>    class_name 类别名称(不是state_name)要获取全部,填写all</string>
+      </property>
+     </widget>
+    </item>
+    <item>
+     <widget class="QLabel" name="label_13">
+      <property name="text">
+       <string>    spot_id: the site ID; enter 0 to fetch all sites</string>
+      </property>
+     </widget>
+    </item>
+    <item>
+     <widget class="QLabel" name="label_14">
+      <property name="text">
+       <string>    percent: fraction of the data to fetch, in (0-1]; enter 1 to fetch everything</string>
+      </property>
+     </widget>
+    </item>
+    <item>
+     <widget class="QLabel" name="label_11">
+      <property name="text">
+       <string>    Example: E1,0,0.2;K10,12,0.3;all,1,1</string>
+      </property>
+     </widget>
+    </item>
+    <item>
+     <layout class="QHBoxLayout" name="horizontalLayout_4">
+      <item>
+       <widget class="QLabel" name="label_9">
+        <property name="text">
+         <string>Filter:</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLineEdit" name="lineEdit_9"/>
+      </item>
+     </layout>
+    </item>
+    <item>
+     <layout class="QHBoxLayout" name="horizontalLayout_7">
+      <item>
+       <widget class="QLabel" name="label_21">
+        <property name="text">
+         <string>Data save path:</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QPushButton" name="pushButton_21">
+        <property name="text">
+         <string>Open folder</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLineEdit" name="lineEdit_21"/>
+      </item>
+     </layout>
+    </item>
+    <item>
+     <widget class="QPushButton" name="pushButton_1">
+      <property name="text">
+       <string>Tool 1: generate data</string>
+      </property>
+     </widget>
+    </item>
+    <item>
+     <widget class="QLabel" name="label_17">
+      <property name="text">
+       <string>--------------------------------------------------------------------------------------</string>
+      </property>
+     </widget>
+    </item>
+    <item>
+     <widget class="QLabel" name="label_26">
+      <property name="text">
+       <string/>
+      </property>
+     </widget>
+    </item>
+   </layout>
+  </widget>
+  <widget class="QWidget" name="">
+   <layout class="QVBoxLayout" name="verticalLayout_3">
+    <item alignment="Qt::AlignHCenter">
+     <widget class="QLabel" name="label_23">
+      <property name="text">
+       <string>Tool 2: fetch all generic classes and generate a tfrecord</string>
+      </property>
+     </widget>
+    </item>
+    <item>
+     <layout class="QHBoxLayout" name="horizontalLayout_5">
+      <item>
+       <widget class="QLabel" name="label_18">
+        <property name="text">
+         <string>Non-generic fields:</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLineEdit" name="lineEdit_18"/>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_19">
+        <property name="text">
+         <string>Data percentage:</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLineEdit" name="lineEdit_19"/>
+      </item>
+     </layout>
+    </item>
+    <item>
+     <layout class="QHBoxLayout" name="horizontalLayout_8">
+      <item>
+       <widget class="QLabel" name="label_24">
+        <property name="text">
+         <string>Data save path:</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QPushButton" name="pushButton_24">
+        <property name="text">
+         <string>Open folder</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLineEdit" name="lineEdit_24"/>
+      </item>
+     </layout>
+    </item>
+    <item>
+     <widget class="QPushButton" name="pushButton_2">
+      <property name="text">
+       <string>Tool 2: generate data</string>
+      </property>
+     </widget>
+    </item>
+    <item>
+     <widget class="QLabel" name="label_22">
+      <property name="text">
+       <string>--------------------------------------------------------------------------------------</string>
+      </property>
+     </widget>
+    </item>
+    <item>
+     <widget class="QLabel" name="label_2">
+      <property name="text">
+       <string/>
+      </property>
+     </widget>
+    </item>
+   </layout>
+  </widget>
+  <widget class="QWidget" name="">
+   <layout class="QVBoxLayout" name="verticalLayout_4">
+    <item alignment="Qt::AlignHCenter">
+     <widget class="QLabel" name="label_20">
+      <property name="text">
+       <string>Tool 3: look up spot_id by project folder name</string>
+      </property>
+     </widget>
+    </item>
+    <item>
+     <layout class="QHBoxLayout" name="horizontalLayout_6">
+      <item>
+       <widget class="QLabel" name="label_25">
+        <property name="text">
+         <string>Project folder name:</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QPushButton" name="pushButton_25">
+        <property name="text">
+         <string>Open folder</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLineEdit" name="lineEdit_25"/>
+      </item>
+     </layout>
+    </item>
+    <item>
+     <widget class="QPushButton" name="pushButton_3">
+      <property name="text">
+       <string>Tool 3: submit</string>
+      </property>
+     </widget>
+    </item>
+   </layout>
+  </widget>
+  <widget class="QWidget" name="">
+   <property name="geometry">
+    <rect>
+     <x>60</x>
+     <y>30</y>
+     <width>518</width>
+     <height>119</height>
+    </rect>
+   </property>
+   <layout class="QVBoxLayout" name="verticalLayout">
+    <item>
+     <widget class="QLabel" name="label_15">
+      <property name="text">
+       <string>Database configuration</string>
+      </property>
+     </widget>
+    </item>
+    <item>
+     <layout class="QHBoxLayout" name="horizontalLayout_3">
+      <item>
+       <widget class="QLabel" name="label_3">
+        <property name="text">
+         <string>Host:  </string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLineEdit" name="lineEdit_3"/>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_4">
+        <property name="text">
+         <string>Port:     </string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLineEdit" name="lineEdit_4"/>
+      </item>
+     </layout>
+    </item>
+    <item>
+     <layout class="QHBoxLayout" name="horizontalLayout_2">
+      <item>
+       <widget class="QLabel" name="label_5">
+        <property name="text">
+         <string>User:  </string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLineEdit" name="lineEdit_5"/>
+      </item>
+      <item>
+       <widget class="QLabel" name="label_6">
+        <property name="text">
+         <string>Password: </string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLineEdit" name="lineEdit_6"/>
+      </item>
+     </layout>
+    </item>
+    <item>
+     <layout class="QHBoxLayout" name="horizontalLayout">
+      <item>
+       <widget class="QLabel" name="label_7">
+        <property name="text">
+         <string>dbName:</string>
+        </property>
+       </widget>
+      </item>
+      <item>
+       <widget class="QLineEdit" name="lineEdit_7"/>
+      </item>
+      <item>
+       <widget class="QPushButton" name="pushButton">
+        <property name="text">
+         <string>Test database connection</string>
+        </property>
+       </widget>
+      </item>
+     </layout>
+    </item>
+    <item>
+     <widget class="QLabel" name="label_16">
+      <property name="text">
+       <string>--------------------------------------------------------------------------------------</string>
+      </property>
+     </widget>
+    </item>
+   </layout>
+  </widget>
+ </widget>
+ <resources/>
+ <connections/>
+</ui>

+ 8 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/__init__.py

@@ -0,0 +1,8 @@
+"""
+# File       : __init__.py
+# Time       :21.7.29 10:05
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""

+ 93 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/check_data_step2.py

@@ -0,0 +1,93 @@
+"""
+# File       : check_data_step2.py
+# Time       :21.5.29 12:26
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Second-pass quality check on the files labeled by the annotation vendor, making sure every xml file pairs one-to-one with a jpg file. step2_check for images that were missed during labeling
+"""
+import glob
+import shutil
+import sys
+
+from gen_data.utils import pathUtil, fileUtil
+
+class checkFile:
+    def __init__(self, project_path):
+        self.project_path = project_path
+        self.redundant_dir = pathUtil.path_format_join(project_path, "redundant_dir") # 问题xml文件和jpg文件保存路径
+        self.file_names = pathUtil.path_format_join(project_path, "TotalData")  # 图片保存路径
+
+    def diff_check(self, list1, list2, file_type):
+        """
+        Find files that are in list1 but not in list2, move those problem files
+        into the redundant_dir folder, and return their absolute paths
+        :param list1: list of file stems
+        :param list2: list of file stems
+        :param file_type: file extension ('.xml' or '.jpg')
+        :return: list of problem file paths
+        """
+        problem_list = []
+        diff_list = set(list1).difference(list2)
+        for diff in diff_list:
+            problem_file_name = diff + file_type
+            problem_file_path = pathUtil.path_format_join(self.file_names, problem_file_name)
+            move_file_path = pathUtil.path_format_join(self.redundant_dir, problem_file_name)
+            problem_list.append(problem_file_path)
+            shutil.move(problem_file_path, move_file_path)
+        if len(problem_list) > 0:
+            if file_type == '.xml':
+                problem_list.extend([{'.xml files missing a matching .jpg file': problem_list}])
+                print('These .xml files have no matching .jpg file: %s' % problem_list)
+            else:
+                problem_list.extend([{'.jpg files missing a matching .xml file': problem_list}])
+                print('These .jpg files have no matching .xml file: %s' % problem_list)
+        return problem_list
+
+
+    def check_file(self, xml_name_list, jpg_name_list):
+        """
+        筛选出xml和jpg无法一一对应的问题文件路径。并组成列表并返回
+        :param xml_name_list: xml文件的列表
+        :param jpg_name_list: jpg文件的列表
+        :return: 问题文件列表
+        """
+        diff1 = self.diff_check(xml_name_list, jpg_name_list, '.xml')
+        diff2 = self.diff_check(jpg_name_list, xml_name_list, '.jpg')
+        problem_list = diff1 + diff2
+        return problem_list
+
+    def main(self):
+
+        xml_name_list = [pathUtil.path_format(file).split('/')[-1].split('.xml')[0] for file in
+                         glob.glob(self.file_names + '/*.xml')]
+        jpg_name_list = [pathUtil.path_format(file).split('/')[-1].split('.jpg')[0] for file in
+                         glob.glob(self.file_names + '/*.jpg')]
+        if len(xml_name_list) + len(jpg_name_list) < 1:
+            mess = 'No data found; please check the data under %s' % self.file_names
+            print(mess)
+            return 0, mess
+            # sys.exit(-1)
+        else:
+            pathUtil.mkdir_new(self.redundant_dir)
+            problem_list = self.check_file(xml_name_list, jpg_name_list)
+            if problem_list:
+                fileUtil.writelog(problem_list)
+                print('Problem files were moved to: %s' % self.redundant_dir)
+            else:
+                print('Check complete: xml and jpg files all match!')
+            print('\n-----------------------------step2 complete-----------------------------\n')
+            return 1, None
+
+# if __name__ == "__main__":
+#     checkFile().main()
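+
+# A minimal usage sketch (hypothetical path): main() returns (1, None) when the
+# check passes and (0, message) when TotalData/ holds no jpg/xml files at all:
+#
+#     status, message = checkFile(r'Z:\data2\fengyang\sunwin\test').main()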
+

+ 154 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/classFeatureImages_step3.py

@@ -0,0 +1,154 @@
+"""
+# File       : classFeatureImages_step3.py
+# Time       :21.5.28 14:35
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Crop the labeled feature patches out of the annotated images and sort them by class. step3_crop the annotated feature patches and file them by class
+"""
+import math
+import cv2
+import shutil
+import time
+import numpy as np
+from tqdm import tqdm
+from gen_data.utils import coorUtil, profileUtil, pathUtil, fileUtil
+import threading
+
+class classFeatureImages:
+
+    def __init__(self, project_path):
+        self.redundant_dir = pathUtil.path_format_join(project_path, "redundant_data")
+        self.img_dir = pathUtil.path_format_join(project_path, "TotalData") # 图片保存路径
+        self.class_img_dir = pathUtil.path_format_join(project_path, "class_img_dir") # 分类好的图片保存路径
+
+
+    def cv_imread(self, filePath):
+        """
+        Read an image; works around cv2.imread() failing on paths that contain non-ASCII (e.g. Chinese) characters
+        :param filePath: file path
+        :return:
+        """
+        cv_img = cv2.imdecode(np.fromfile(filePath, dtype=np.uint8), -1)
+        return cv_img
+
+
+    @staticmethod
+    def splitdf(df, num):
+        """Split df into num equal chunks plus a final chunk for the remainder."""
+        linenum = math.floor(len(df) / num)
+        pdlist = []
+        for i in range(num):
+            pd1 = df[i * linenum:(i + 1) * linenum]
+            pdlist.append(pd1)
+        # the loop above covers rows [0, num * linenum); only the tail is left.
+        # (the original sliced from (num - 1) * linenum, which handed the last
+        # chunk to two threads and cropped those rows twice)
+        pd1 = df[num * linenum:len(df)]
+        pdlist.append(pd1)
+        return pdlist
+
+    def isexists_class(self, df, class_img_dir):
+        """
+        Collect every class from the ['class_name'] column, then create a folder for any class that does not have one yet
+        :param df: DataFrame
+        :param class_img_dir: parent directory of the class folders
+        :return:
+        """
+        group_df = df.groupby('class_name')
+        for k, _ in group_df:
+            class_dir = pathUtil.path_format_join(class_img_dir, str(k))
+            pathUtil.mkdir(class_dir)
+
+
+
+    def gen_class_img_thread(self, csv_df, img_dir, class_img_dir):
+        self.isexists_class(csv_df, class_img_dir)
+        threads = []
+        df_list = self.splitdf(csv_df, 10)
+        for df in df_list:
+            t = threading.Thread(target=self._gen_class_img, args=(df, img_dir, class_img_dir))
+            threads.append(t)
+        for t in threads:
+            t.daemon = True  # setDaemon() is deprecated in favor of the attribute
+            t.start()
+        for t in threads:
+            t.join()
+        time.sleep(10)
+        print('\nTotal number of bounding boxes:', len(csv_df))
+
+    def _gen_class_img(self, csv_df, img_dir, class_img_dir):
+        """
+        Crop every annotated feature patch out of the images and store it in the folder of its class
+        :param img_dir: directory the original images are read from
+        :param class_img_dir: parent directory of the per-class folders; joined with class_name to build each crop's absolute save path
+        :return:
+        """
+        errors = []
+        for index, row in tqdm(csv_df.iterrows(),total=len(csv_df), ncols=60):
+            filename, class_name = row["filename"], row["class_name"]
+            image_path = pathUtil.path_format_join(img_dir, filename)
+            error, error_dic = coorUtil.check_coor(image_path, row)
+            if error:
+                old_file_path  = '.'.join(image_path.split('.')[0:-1])
+                new_file = '.'.join(image_path.split('\\')[-1].split('/')[-1].split('.')[0:-1])
+                new_file_path = pathUtil.path_format_join(self.redundant_dir,new_file)
+                try:
+                    shutil.move(old_file_path+'.jpg', new_file_path+'.jpg')
+                    shutil.move(old_file_path+'.xml', new_file_path+'.xml')
+                    errors.extend([error_dic])
+                except:
+                    pass
+                continue
+            xmin, ymin, xmax, ymax = row["xmin"], row["ymin"], row["xmax"], row["ymax"]
+            class_file = pathUtil.path_format_join(class_img_dir, "{}".format(row["class_name"]))
+            image = self.cv_imread(image_path)
+            cropimg = image[int(ymin):int(ymax), int(xmin):int(xmax)]
+            img_path = pathUtil.path_format_join(class_file, filename.split('.jpg')[0] + '_' + str(xmin) + '_' + str(ymin) + '_' + str(xmax) + '_' + str(ymax) + '.jpg')
+            cv2.imwrite(img_path, cropimg)
+        if errors:
+            print('Problematic annotated images:', errors)
+            fileUtil.writelog(errors)
+
+    def main(self):
+        pathUtil.mkdir_new(self.class_img_dir)
+        csv_df = profileUtil.xmlUtil().xml_parse(self.img_dir)
+        self.gen_class_img_thread(csv_df, self.img_dir,self.class_img_dir)
+        print('\n-----------------------------step3 complete-----------------------------\n')
+
+
+# if __name__ == '__main__':
+#     classFeatureImages().main()
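+
+# Usage sketch (hypothetical path): crops every labeled box found under
+# TotalData/ into class_img_dir/<class_name>/ using 10 worker threads:
+#
+#     classFeatureImages(r'Z:\data2\fengyang\sunwin\test').main()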

+ 39 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/genAnn_step5.py

@@ -0,0 +1,39 @@
+"""
+# File       : genAnn_step5.py
+# Time       :21.5.31 9:10
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Generate the detailed per-image info (file path, feature box coordinates, and other key fields). step5_generate the image info files
+"""
+import time
+from gen_data.utils import profileUtil, pathUtil
+
+class ganAnn:
+    def __init__(self, project_path):
+        self.data_dir = pathUtil.path_format_join(project_path, "TotalData")
+        self.label_csv_dir = pathUtil.path_format_join(project_path, "label_csv")
+        self.label_txt_dir = pathUtil.path_format_join(project_path, "label_txt")
+        self.class_index_json = pathUtil.path_format_join(project_path, "class_index/ob_classes.json")
+        self.total_data_txt = pathUtil.path_format_join(self.label_txt_dir, "total_data.txt")
+        self.total_data_yolo_txt = pathUtil.path_format_join(project_path, "yolo_txt/total_data")
+        self.total_data_csv = pathUtil.path_format_join(self.label_csv_dir, "total_data.csv")
+        self.class_index_json = pathUtil.path_format_join(project_path, "class_index/ob_classes.json")
+        self.flag = "yolo"
+
+    def main(self):
+        pathUtil.mkdir_new(self.label_csv_dir)
+        pathUtil.mkdir_new(self.label_txt_dir)
+        pathUtil.mkdir_new(self.total_data_yolo_txt)
+        xml_util = profileUtil.xmlUtil()
+        if self.flag.lower() == 'yolo':
+            xml_util.xml_to_yolo_txt(self.total_data_yolo_txt, self.data_dir, self.class_index_json)
+            xml_util.xml_to_csv(self.total_data_csv, self.data_dir)
+        else:
+            xml_util.xml_to_txt(self.total_data_txt, self.data_dir, self.class_index_json)
+            xml_util.xml_to_csv(self.total_data_csv, self.data_dir)
+        print('\n-----------------------------step5 complete-----------------------------\n')
+        time.sleep(0.5)
+
+# if __name__ == "__main__":
+#     ganAnn(r'Z:\data2\fengyang\sunwin\test').main()
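+
+# For reference, with self.flag == 'yolo' each image gets its own label file
+# under yolo_txt/total_data/labels/ (written by xml_to_yolo_txt), one box per
+# line in the comma-separated, [0, 1]-normalized form:
+#
+#     class_index,x_center,y_center,width,height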

+ 58 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/gen_class_index_step4.py

@@ -0,0 +1,58 @@
+"""
+# File       : gen_class_index_step4.py
+# Time       :21.5.29 10:10
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Generate the full set of class info for the images. step4_generate the class index files
+"""
+import json
+from gen_data.utils import profileUtil, pathUtil
+
+class genClassIndex:
+    def __init__(self, project_path):
+        self.data_dir = pathUtil.path_format_join(project_path, "TotalData")  # 存放xml和jpg的文件夹
+        self.index_dir = pathUtil.path_format_join(project_path, "class_index")  # 存放标注种类的文件夹
+        self.class_index_txt_path = pathUtil.path_format_join(project_path, "class_index/ob_classes.txt") # 存放标注种类的txt文件
+        self.class_index_json_path = pathUtil.path_format_join(project_path, "class_index/ob_classes.json") # 存放标注种类的json文件
+
+    def gen_class_index_json_txt(self, xml_df, class_index_json_path, class_index_txt_path):
+        """
+        Write the collected class names to a json file and a txt file
+        :param xml_df: DataFrame parsed from the xml annotations
+        :param class_index_json_path: path of the json class-index file
+        :param class_index_txt_path: path of the txt class-index file
+        :return:
+        """
+        class_index_dict = dict()
+        class_name_set = set(xml_df["class_name"].values)
+        class_name_list = list(class_name_set)
+        class_name_list.sort()
+        for index, class_name in enumerate(class_name_list):
+            class_index_dict[class_name] = index
+
+        # save json file
+        with open(class_index_json_path, "w") as json_file:
+            json_file.write(json.dumps(class_index_dict, indent=4))
+        print("总类别数: %d" % (len(class_name_list)))
+        print("写入 %s 完成"%(class_index_json_path))
+
+
+        # save txt file
+        with open(class_index_txt_path, "w") as txt_file:
+            index_class_dict = {value: key for key, value in class_index_dict.items()}
+            for i in range(len(index_class_dict)):
+                txt_file.write(index_class_dict[i])
+                txt_file.write("\n")
+
+        print("写入 %s 完成"%(class_index_txt_path))
+
+    def main(self):
+        pathUtil.mkdir_new(self.index_dir)
+        xml_df = profileUtil.xmlUtil().xml_parse(self.data_dir)
+        self.gen_class_index_json_txt(xml_df, self.class_index_json_path, self.class_index_txt_path)
+        print('\n-----------------------------step4 complete-----------------------------\n')
+
+
+# if __name__ == "__main__":
+#     genClassIndex().main()
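+
+# Output sketch: given the classes {'A1', 'B2'}, the two files would hold
+#   ob_classes.json -> {"A1": 0, "B2": 1}
+#   ob_classes.txt  -> one class name per line, ordered by index (A1, then B2)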

+ 138 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/gen_pb_tfrecord_step7.py

@@ -0,0 +1,138 @@
+"""
+# File       : gen_pb_tfrecord_step7.py
+# Time       :21.6.4 10:10
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Generate the pbtxt file and the TFRecord files. step7_generate the target artifacts
+"""
+import os
+import io
+import json
+import pandas as pd
+import tensorflow.compat.v1 as tf
+from PIL import Image
+from collections import namedtuple
+from research.object_detection.utils import dataset_util
+from research.object_detection.utils import label_map_util
+from gen_data.utils import pathUtil
+
+class genPb:
+    def __init__(self, project_path):
+        self.class_json_path = pathUtil.path_format_join(project_path, "class_index/ob_classes.json")
+        self.pbtxt_path = pathUtil.path_format_join(project_path, "ob.pbtxt")
+
+    def gen_pbtxt(self):
+        with open(self.class_json_path, 'r') as json_file:
+            json_dict = json.load(json_file)
+        json_len = len(json_dict)
+        total_list = ["" for i in range(json_len)]
+        for key, value in json_dict.items():
+            line_content_list = []
+            line_content_list.append("item {\n")
+            line_content_list.append("  id: {}\n".format(value + 1))
+            line_content_list.append("  name: '{}'\n".format(key))
+            line_content_list.append("}\n")
+            line_content_list.append("\n")
+            fill_content = "".join(line_content_list)
+            total_list[value] = fill_content
+
+        with open(self.pbtxt_path, 'w') as pbtxt_file:
+            for i in total_list:
+                pbtxt_file.write(i)
+
+    def main(self):
+        self.gen_pbtxt()
+        print('Successfully created %s' % self.pbtxt_path)
+
+class genTfrecord:
+    def __init__(self, project_path):
+        self.tfrecord = pathUtil.path_format_join(project_path, "tf_record")
+
+        flags = tf.app.flags
+        flags.DEFINE_string("val_csv_input", pathUtil.path_format_join(project_path, "label_csv/val_data.csv"),
+                            "Path to the CSV input")
+        flags.DEFINE_string("images_input", pathUtil.path_format_join(project_path, "TotalData"),
+                            "Path to the images input")
+        flags.DEFINE_string("val_output_path",
+                            pathUtil.path_format_join(project_path, "tf_record/val_data.record"),
+                            "Path to output TFRecord")
+        flags.DEFINE_string("label_map_path", pathUtil.path_format_join(project_path, "ob.pbtxt"),
+                            "Path to label map proto")
+        flags.DEFINE_string("train_csv_input",
+                            pathUtil.path_format_join(project_path, "label_csv/train_data.csv"),
+                            "Path to the CSV input")
+        flags.DEFINE_string("train_output_path",
+                            pathUtil.path_format_join(project_path, "tf_record/train_data.record"),
+                            "Path to output TFRecord")
+        self.FLAGS = flags.FLAGS
+
+
+    def split(self, df, group):
+        data = namedtuple("data", ["filename", "object"])
+        gb = df.groupby(group)
+        return [data(filename, gb.get_group(x)) for filename, x in
+                zip(gb.groups.keys(), gb.groups)]
+
+
+    def create_tf_example(self, group, label_map_dict, images_path):
+        with tf.gfile.GFile(os.path.join(
+                images_path, "{}".format(group.filename)), "rb") as fid:
+            encoded_jpg = fid.read()
+        encoded_jpg_io = io.BytesIO(encoded_jpg)
+        image = Image.open(encoded_jpg_io)
+        width, height = image.size
+
+        filename = group.filename.encode("utf8")
+        image_format = b"jpg"
+        xmins, xmaxs, ymins, ymaxs, classes_text, classes = [], [], [], [], [], []
+        for index, row in group.object.iterrows():
+            xmins.append(row["xmin"] / width)
+            xmaxs.append(row["xmax"] / width)
+            ymins.append(row["ymin"] / height)
+            ymaxs.append(row["ymax"] / height)
+            classes_text.append(str(row['class_name']).encode("utf8"))
+            classes.append(label_map_dict[str(row['class_name'])])
+
+        tf_example = tf.train.Example(features=tf.train.Features(feature={
+            "image/height": dataset_util.int64_feature(height),
+            "image/width": dataset_util.int64_feature(width),
+            "image/filename": dataset_util.bytes_feature(filename),
+            "image/source_id": dataset_util.bytes_feature(filename),
+            "image/encoded": dataset_util.bytes_feature(encoded_jpg),
+            "image/format": dataset_util.bytes_feature(image_format),
+            "image/object/bbox/xmin": dataset_util.float_list_feature(xmins),
+            "image/object/bbox/xmax": dataset_util.float_list_feature(xmaxs),
+            "image/object/bbox/ymin": dataset_util.float_list_feature(ymins),
+            "image/object/bbox/ymax": dataset_util.float_list_feature(ymaxs),
+            "image/object/class/text": dataset_util.bytes_list_feature(classes_text),
+            "image/object/class/label": dataset_util.int64_list_feature(classes),
+        }))
+        return tf_example
+
+
+    def gen_tfrecord(self, val_output_path, val_csv_input):
+        # generate val_tfrecord
+        writer = tf.python_io.TFRecordWriter(val_output_path)
+        label_map_dict = label_map_util.get_label_map_dict(self.FLAGS.label_map_path)
+        images_path = self.FLAGS.images_input
+        examples = pd.read_csv(val_csv_input)
+        grouped = self.split(examples, "filename")
+        for group in grouped:
+            tf_example = self.create_tf_example(group, label_map_dict, images_path)
+            writer.write(tf_example.SerializeToString())
+
+        writer.close()
+        print("成功创建 %s" % val_output_path)
+
+
+    def main(self):
+        pathUtil.mkdir_new(self.tfrecord)
+        self.gen_tfrecord(self.FLAGS.val_output_path, self.FLAGS.val_csv_input)
+        self.gen_tfrecord(self.FLAGS.train_output_path, self.FLAGS.train_csv_input)
+        print('\n-----------------------------step7 complete-----------------------------\n')
+
+
+# if __name__ == "__main__":
+#     genPb(r'Z:\data2\fengyang\sunwin\test').main()
+#     genTfrecord(r'Z:\data2\fengyang\sunwin\test').main()
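+
+# Note: genTfrecord defines its tf.app.flags in __init__, so it is meant to be
+# instantiated once per process; a second instantiation would trip absl's
+# duplicate-flag check. A minimal run over a prepared project folder:
+#
+#     genPb(r'Z:\data2\fengyang\sunwin\test').main()
+#     genTfrecord(r'Z:\data2\fengyang\sunwin\test').main()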

+ 232 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/splitTrainVal_step6.py

@@ -0,0 +1,232 @@
+"""
+# File       : splitTrainVal_step6.py
+# Time       :21.6.1 13:49
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Split the images, already sorted by the class index file, into training and validation sets. step6_train/val split
+"""
+import os
+import math
+import sys
+import shutil
+import pandas as pd
+from tqdm import tqdm
+from gen_data.utils import profileUtil, pathUtil
+from sklearn.model_selection import train_test_split
+
+class splitTrainVal:
+    def __init__(self, project_path):
+        self.project_name = pathUtil.path_format(project_path).split('/')[-1]
+        self.data_dir = pathUtil.path_format_join(project_path, "TotalData")
+        label_csv_dir = pathUtil.path_format_join(project_path, "label_csv")
+        label_txt_dir = pathUtil.path_format_join(project_path, "label_txt")
+        self.train_dir = pathUtil.path_format_join(project_path, "train_data")
+        self.val_dir = pathUtil.path_format_join(project_path, "val_data")
+        self.class_index_json = pathUtil.path_format_join(project_path, "class_index/ob_classes.json")
+        self.total_data_csv = pathUtil.path_format_join(label_csv_dir, "total_data.csv")
+        self.split_ratio = 0.15
+
+        self.flag = "yolo"
+        self.train_data_txt = pathUtil.path_format_join(label_txt_dir, "train_data.txt")
+        self.train_data_yolo_txt = pathUtil.path_format_join(project_path, "yolo_txt/train_data")
+        self.train_data_csv = pathUtil.path_format_join(label_csv_dir, "train_data.csv")
+        self.val_data_txt = pathUtil.path_format_join(label_txt_dir, "val_data.txt")
+        self.val_data_yolo_txt = pathUtil.path_format_join(project_path, "yolo_txt/val_data")
+        self.val_data_csv = pathUtil.path_format_join(label_csv_dir, "val_data.csv")
+        self.class_index_txt = pathUtil.path_format_join(project_path, "class_index/ob_classes.txt")
+
+
+
+    def split_train_val(self):
+        """
+        Split the data into a training set and a validation set.
+        First group the annotations by filename and make sure no class appears twice within one image. Then walk all images: any new class is added to the dict, and each image is dynamically assigned to one of its classes, which keeps the per-class image distribution relatively even
+
+        :return:
+        """
+        pathUtil.mkdir(self.train_dir)
+        pathUtil.mkdir(self.val_dir)
+        total_df = pd.read_csv(self.total_data_csv)
+        # dict of all classes seen so far
+        class_dict = dict()
+        # group by image filename and iterate
+        for key, values in tqdm(total_df.groupby('filename'),total=len(total_df.groupby('filename')), ncols=80):
+            flag = True
+            key = key.split('.jpg')[0]
+            # drop duplicate classes within this image (each class is counted once per image);
+            # assigned back instead of inplace=True to avoid pandas' SettingWithCopyWarning on a groupby slice
+            values = values.drop_duplicates(subset=['class_name'], keep='first')
+            # list of the classes contained in this image
+            class_name = []
+            # iterate over the annotations of this image
+            for k, v in values.iterrows():
+                # this class is not in the class dict yet
+                if class_dict.get(v['class_name']) is None:
+                    path_list = [key]
+                    class_dict[v['class_name']] = path_list
+                    # BUG fix: flag prevents one image from being assigned to two classes at the same time
+                    flag = False
+                    break
+                else:
+                    class_name.append(v['class_name'])
+            # an image with exactly one class follows that class
+            if len(class_name) == 1 and flag:
+                class_dict[class_name[0]].append(key)
+            # later images go to whichever of their classes currently holds the fewest files
+            elif len(class_name) > 1 and flag:
+                list_len = math.inf
+                min_class_key = None
+                for class_k in class_name:
+
+                    if list_len > len(class_dict[class_k]):
+                        list_len = len(class_dict[class_k])
+                        min_class_key = class_k
+                class_dict[min_class_key].append(key)
+
+            else:
+                continue
+        train_data, val_data = [], []
+        train_distribute = {}
+        val_distribute = {}
+        for key in class_dict:
+            clas = class_dict.get(key)
+            try:
+                split_ratio = float(self.split_ratio)
+            except:
+                print('The split_ratio field in the conf file cannot be converted to a number; please check it')
+                sys.exit(1)
+            train_d, val_d = train_test_split(clas, random_state=2020, test_size=split_ratio, shuffle=True)
+            train_distribute[key] = len(train_d)
+            val_distribute[key] = len(val_d)
+            train_data.extend(train_d)
+            val_data.extend(val_d)
+        print('Data distribution:')
+        print('train set: %s, val set: %s \n' % (train_distribute, val_distribute))
+        return train_data, val_data
+
+
+    def total_copy_TrainVal(self, path_data, total_path, train_val_path):
+        """
+        Copy the jpg and xml files into the corresponding train/val dataset folder
+        :param path_data: the train/val split of file stems
+        :param total_path: path of the folder holding all the data
+        :param train_val_path: path of the train/val folder
+        :return:
+        """
+        for data in tqdm(path_data, total=len(path_data), ncols=80):
+            xml_path = pathUtil.path_format_join(total_path, str(data+'.xml'))
+            jpg_path = pathUtil.path_format_join(total_path, str(data+'.jpg'))
+            xml_new_path = pathUtil.path_format_join(train_val_path, str(str(xml_path).split('/')[-1]))
+            jpg_new_path = pathUtil.path_format_join(train_val_path, str(str(jpg_path).split('/')[-1]))
+            shutil.copyfile(xml_path, xml_new_path)
+            shutil.copyfile(jpg_path, jpg_new_path)
+
+    def copy_yoloTxt_TrainVal(self, path_data, total_path, train_val_path, yolo_path):
+        """
+        Copy the jpg and xml files into the corresponding train/val dataset folder, plus a copy of each jpg into the YOLO images folder
+        :param path_data: the train/val split of file stems
+        :param total_path: path of the folder holding all the data
+        :param train_val_path: path of the train/val folder
+        :return:
+        """
+        yolo_image_path = pathUtil.path_format_join(yolo_path, 'images')
+        pathUtil.mkdir(yolo_image_path)
+        for data in tqdm(path_data, total=len(path_data), ncols=80):
+            xml_path = pathUtil.path_format_join(total_path, str(data + '.xml'))
+            jpg_path = pathUtil.path_format_join(total_path, str(data + '.jpg'))
+            xml_new_path = pathUtil.path_format_join(train_val_path, str(str(xml_path).split('/')[-1]))
+            jpg_new_path = pathUtil.path_format_join(train_val_path, str(str(jpg_path).split('/')[-1]))
+            jpg_yolo_path = pathUtil.path_format_join(yolo_image_path, str(str(jpg_path).split('/')[-1]))
+            shutil.copyfile(xml_path, xml_new_path)
+            shutil.copyfile(jpg_path, jpg_new_path)
+            shutil.copyfile(jpg_path, jpg_yolo_path)
+
+    def xml2_txtCsv(self, data_txt, data_csv, data_dir):
+        """
+        Save the train/val data details to a txt file and a csv file
+        :param data_txt: txt file
+        :param data_csv: csv file
+        :return:
+        """
+        xml_ = profileUtil.xmlUtil()
+        xml_.xml_to_txt(data_txt, data_dir, self.class_index_json)
+        xml_.xml_to_csv(data_csv, data_dir)
+
+    def xml2_yolotxtCsv(self, data_txt, data_csv, data_dir):
+        """
+        Save the train/val data details to YOLO txt label files and a csv file
+        :param data_txt: folder for the YOLO txt labels
+        :param data_csv: csv file
+        :return:
+        """
+        xml_ = profileUtil.xmlUtil()
+        xml_.xml_to_yolo_txt(data_txt, data_dir, self.class_index_json)
+        xml_.xml_to_csv(data_csv, data_dir)
+
+    def yaml_write(self):
+        """
+        train: /data/humaocheng/sunwin_project/data/kuye/yolo_txt/train_data
+        val: /data/humaocheng/sunwin_project/data/kuye/yolo_txt/val_data
+        # number of classes
+        nc: 6
+        # class names
+        names : ['c2', 'a1', 'b2', 'b1', 'c1', 'a2']
+
+
+        Write the yolo_txt/<project_name>.yaml file
+        :return:
+        """
+        yaml_train = 'train: ' + self.train_data_yolo_txt  # YAML needs a space after the colon
+        yaml_val = 'val: ' + self.val_data_yolo_txt
+        with open(self.class_index_txt, 'r', encoding='utf8') as f:
+            names = [name.strip() for name in f.readlines()]
+        yaml_nc = 'nc: ' + str(len(names))
+        yaml_names = 'names: ' + str(names)
+
+        yaml = pathUtil.path_format_join('/'.join(self.train_data_yolo_txt.split('/')[0:-1]), self.project_name + '.yaml')
+        with open(yaml, 'w', encoding='utf8') as f:
+            f.write(yaml_train)
+            f.write('\n')
+            f.write(yaml_val)
+            f.write('\n')
+            f.write(yaml_nc)
+            f.write('\n')
+            f.write(yaml_names)
+
+
+    def main(self):
+        if self.flag.lower() == 'yolo':
+            pathUtil.mkdir_new(self.train_dir)
+            pathUtil.mkdir_new(self.val_dir)
+            pathUtil.mkdir_new(self.train_data_yolo_txt)
+            pathUtil.mkdir_new(self.val_data_yolo_txt)
+            train_data, val_data = self.split_train_val()
+            self.copy_yoloTxt_TrainVal(train_data, self.data_dir, self.train_dir, self.train_data_yolo_txt)
+            self.copy_yoloTxt_TrainVal(val_data, self.data_dir, self.val_dir, self.val_data_yolo_txt)
+            self.yaml_write()
+            self.xml2_yolotxtCsv(self.train_data_yolo_txt, self.train_data_csv, self.train_dir)
+            self.xml2_yolotxtCsv(self.val_data_yolo_txt, self.val_data_csv, self.val_dir)
+
+        else:
+            pathUtil.mkdir_new(self.train_dir)
+            pathUtil.mkdir_new(self.val_dir)
+            train_data, val_data = self.split_train_val()
+            self.total_copy_TrainVal(train_data, self.data_dir, self.train_dir)
+            self.total_copy_TrainVal(val_data, self.data_dir, self.val_dir)
+
+            self.xml2_txtCsv(self.train_data_txt,self.train_data_csv, self.train_dir)
+            self.xml2_txtCsv(self.val_data_txt, self.val_data_csv, self.val_dir)
+        print('train_data / val_data split complete')
+        print('\n-----------------------------step6 complete-----------------------------\n')
+
+# if __name__ == "__main__":
+#     splitTrainVal(r'Z:\data2\fengyang\sunwin\test').main()
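+
+# Output sketch: for a project folder named 'test' with two classes, yaml_write
+# would produce yolo_txt/test.yaml along the lines of:
+#
+#     train: <project>/yolo_txt/train_data
+#     val: <project>/yolo_txt/val_data
+#     nc: 2
+#     names: ['A1', 'B2']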

+ 21 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/test.py

@@ -0,0 +1,21 @@
+"""
+# File       : test.py
+# Time       :21.7.30 18:29
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""
+
+import glob
+import os
+import shutil
+
+for i in glob.glob(r"Z:/data2/fengyang/sunwin/code/data_manage/*.py"):
+    print(i)
+
+    new_path =  r"Z:/data2/fengyang/sunwin/code/data_manage2" +'/'+ i.replace('\\', '/').split('/')[-1]
+    print(new_path)
+    # a = 'copy %s %s'%(i, new_path)
+    # os.system(a)
+    shutil.copy(i, new_path)

+ 8 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/utils/__init__.py

@@ -0,0 +1,8 @@
+"""
+# File       : __init__.py
+# Time       :21.7.30 9:28
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""

+ 75 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/utils/coorUtil.py

@@ -0,0 +1,75 @@
+"""
+# File       : coorUtil.py
+# Time       :21.6.8 17:28
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Check whether an annotation box lies outside the width/height of its background image
+"""
+
+import cv2
+
+def check_coor(image_path, row):
+    error_dic = {}
+    width, height, xmin, ymin, xmax, ymax = row["width"], row["height"], row["xmin"], row["ymin"], row["xmax"], row["ymax"]
+    img = cv2.imread(image_path)
+    error = False
+    if img is None:
+        error = True
+        code = str('Could not read image')
+        error_dic[image_path] = code
+        return error, error_dic
+
+    org_height, org_width = img.shape[:2]
+
+    if org_width != width:
+        error = True
+        code = str('Width mismatch for image: ' + str(width) + '!=' + str(org_width))
+        error_dic[image_path] = code
+
+    if org_height != height:
+        error = True
+        code = str('Height mismatch for image: ' + str(height) + '!=' + str(org_height))
+        error_dic[image_path] = code
+
+    if xmin > org_width:
+        error = True
+        code = str('XMIN > org_width for file')
+        error_dic[image_path] = code
+
+    if xmin <= 0:
+        error = True
+        code = str('XMIN < 0 for file')
+        error_dic[image_path] = code
+
+    if xmax > org_width:
+        error = True
+        code = str('XMAX > org_width for file')
+        error_dic[image_path] = code
+
+    if ymin > org_height:
+        error = True
+        code = str('YMIN > org_height for file')
+        error_dic[image_path] = code
+
+    if ymin <= 0:
+        error = True
+        code = str('YMIN < 0 for file')
+        error_dic[image_path] = code
+
+    if ymax > org_height:
+        error = True
+        code = str('YMAX > org_height for file')
+        error_dic[image_path] = code
+
+    if xmin >= xmax:
+        error = True
+        code = str('xmin >= xmax for file')
+        error_dic[image_path] = code
+
+    if ymin >= ymax:
+        error = True
+        code = str('ymin >= ymax for file')
+        error_dic[image_path] = code
+
+    return error, error_dic
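+
+# Usage sketch: check_coor expects a pandas Series (one row of the annotation
+# DataFrame); 'img/0001.jpg' below is a hypothetical path:
+#
+#     import pandas as pd
+#     row = pd.Series({'width': 1280, 'height': 720,
+#                      'xmin': 10, 'ymin': 20, 'xmax': 200, 'ymax': 240})
+#     error, error_dic = check_coor('img/0001.jpg', row)
+#     # error is True if the file is unreadable, the sizes mismatch, or the box
+#     # falls outside the image; error_dic maps the image path to the reason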

+ 97 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/utils/fileUtil.py

@@ -0,0 +1,97 @@
+"""
+# File       : fileUtil.py
+# Time       :21.5.25 18:40
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: File filtering utilities
+"""
+import json
+import time
+import chardet
+import glob
+import pandas as pd
+from gen_data.utils import pathUtil
+
+class ReadFile:
+    """
+    Readers for the different generated file types
+    label_csv: data under label_csv; returns a dataframe ['filename', 'filepath', 'width', 'height', 'depth', 'class_name','class_id', 'xmin', 'ymin', 'xmax', 'ymax']
+    label_txt: data under label_txt; returns a dataframe ['filename', 'class_id', 'xmin', 'ymin', 'xmax', 'ymax']
+    class_json: label data under class_index; returns a class_dict (key: class_name, value: class_id)
+    """
+
+    def read_json_label(self, class_index_json):
+        with open(class_index_json, "r") as f:
+            class_dict = json.load(f)
+        return class_dict
+
+
+    def read_txt_data(self, txt_dir, total_data_txt):
+        txt_list = []
+        with open(pathUtil.path_format_join(txt_dir, total_data_txt), "r") as f:
+            for line in f.readlines():
+                line = line.strip('\n').split(";")
+                filename = line[0].split("/")[1]
+                bbox = line[1:]
+                for member in bbox:
+                    member = member.split(",")
+                    class_id = member[-1]
+                    x_min,y_min,x_max,y_max = member[0],member[1],member[2],member[3]
+                    value = (filename, class_id, x_min, y_min, x_max, y_max)
+                    txt_list.append(value)
+
+        column_name = ["filename", "class_id", "xmin", "ymin", "xmax", "ymax"]
+        txt_df = pd.DataFrame(txt_list, columns=column_name)
+        return txt_df
+
+
+    def read_csv_data(self, csv_dir, total_data_csv):
+        """
+        Read data from a csv file
+        :param csv_dir: csv folder path
+        :param total_data_csv: csv file
+        :return:
+        """
+        csv_df = pd.read_csv(pathUtil.path_format_join(csv_dir, total_data_csv), encoding='utf8')
+        return csv_df
+
+
+
+
+
+def extension_filter(base, extension_str):
+    """
+    Collect files of the given extensions from the directory and its subdirectories, and return their absolute paths
+
+    :param base: the directory to search
+    :param extension_str: comma-separated file extensions from the conf file
+    :return: list of absolute paths of the matching files
+    """
+    extension = extension_str.split(',')
+    fullname_list = []
+    for ex in extension:
+        ex = ex.strip() if ex.strip().startswith('.') else '.' + ex.strip()  # normalize: make sure the extension starts with a dot
+        ex_list = glob.glob(base + '/**/*' + ex, recursive=True)
+        fullname_list.extend(ex_list)
+    return fullname_list
+
+
+def detectCode(path):
+    """
+    Detect the text encoding of a file
+    :param path:
+    :return:
+    """
+    with open(path, 'rb') as file:
+        data = file.read(1000)
+        dicts = chardet.detect(data)
+    return dicts["encoding"]
+
+def writelog(data):
+    now = time.strftime('%Y-%m-%d %H:%M', time.localtime(time.time()))
+    with open('../error_log.txt', 'a') as f:
+        f.write("===============================%s======================================\n"%(now))
+        f.writelines(str(data))
+        f.write("\n")
+
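+# Example (hypothetical folder): collect all jpg/png files recursively;
+# extensions may be listed with or without the leading dot:
+#
+#     files = extension_filter('/data/project_x/TotalData', 'jpg, .png')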

+ 32 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/utils/pathUtil.py

@@ -0,0 +1,32 @@
+"""
+# File       : pathUtil.py
+# Time       :21.5.25 18:13
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Path manipulation utilities
+"""
+import os
+import shutil
+from gen_data.utils import strUtil
+
+def mkdir(new_folder):
+    if not os.path.exists(new_folder):
+        os.makedirs(new_folder)
+
+def mkdir_new(new_folder):
+    if os.path.exists(new_folder):
+        shutil.rmtree(new_folder, ignore_errors=True)
+    os.makedirs(new_folder)
+
+def path_format(path_str):
+    path = strUtil.profile2str(path_str.replace('\\','/'))
+    if str(path).endswith('/'):
+        # drop the trailing slash; the original '/'.join(path[0:-1]) would have
+        # re-joined every character of the string with '/' separators
+        return path[0:-1]
+    else:
+        return path
+
+def path_format_join(path_str1, path_str2):
+    return os.path.join(path_format(path_str1), path_format(path_str2)).replace('\\','/')
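+
+# Behavior sketch: path_format normalizes separators, strips quotes, and drops
+# a trailing slash; path_format_join joins two normalized pieces:
+#
+#     path_format('C:\\data\\imgs\\')             # -> 'C:/data/imgs'
+#     path_format_join('C:\\data', 'imgs/a.jpg')  # -> 'C:/data/imgs/a.jpg'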
+
+

+ 245 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/utils/profileUtil.py

@@ -0,0 +1,245 @@
+"""
+# File       : profileUtil.py
+# Time       :21.5.25 16:16
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Configuration file utilities
+"""
+
+import glob
+from xml.dom import minidom
+import pandas as pd
+import xml.etree.ElementTree as ET
+from gen_data.utils import fileUtil, pathUtil
+
+class xmlUtil:
+    """
+    Utility class for xml files
+    """
+
+    def xml_parse(self, data_dir, find='filename'):
+        """
+
+        :param data_dir: xml file path
+        :return: dataframe (filename, path, width, height, depth, class_name, x_min, y_min, x_max, y_max)
+        """
+        error_xml_list = []
+        xml_list = []
+        for xml_file in glob.glob(data_dir + '/*.xml'):
+            xml_file_path = pathUtil.path_format(xml_file)
+            try:
+                tree = ET.parse(xml_file_path)
+            except:
+                error_xml_list.append(xml_file_path)
+                continue
+            root = tree.getroot()
+            filename = root.find(find).text
+            width, height, depth = int(root.find('size')[0].text), int(root.find('size')[1].text), int(
+                root.find('size')[2].text)
+            for member in root.findall('object'):
+                class_name = member[0].text.upper()
+                x_min, y_min, x_max, y_max = int(member[4][0].text), int(member[4][1].text), int(
+                    member[4][2].text), int(member[4][3].text)
+                value = (filename, width, height, depth, class_name, x_min, y_min, x_max, y_max)
+                xml_list.append(value)
+
+
+        column_name = ['filename', 'width', 'height','depth' ,'class_name', 'xmin', 'ymin', 'xmax', 'ymax']
+        xml_df = pd.DataFrame(xml_list, columns=column_name)
+        if error_xml_list:
+            fileUtil.writelog(error_xml_list)
+            print('xml files that failed to parse: %s' % (error_xml_list))
+        return xml_df
+
+    def xml_to_csv(self, total_data_csv, path):
+        """
+        Convert the xml annotations to a csv file
+
+        :return:
+        """
+        xml_df = self.xml_parse(path)
+        xml_df.to_csv(total_data_csv, index=False)
+        print("%s 写入 %d 行"%(total_data_csv, xml_df.shape[0]))
+
+    def xml_to_txt(self, total_data_txt, path, class_index_json):
+        """
+        Convert the xml files to a single txt label file. Each line describes
+        one image and is formatted as
+
+        image_filename;x_min,y_min,x_max,y_max,class_index;...
+
+        with one ';'-separated group per bounding box, e.g.:
+        K2_112647_540.jpg;456,1,516,104,4;662,1,708,102,4
+
+
+        @param total_data_txt: the txt file the labels are written to
+        @param path: the folder holding the images and xml annotations
+        @param class_index_json: a json file whose keys are class names and values are class indices
+        @return:
+        """
+        label_content_list = []
+        class_dict = fileUtil.ReadFile().read_json_label(class_index_json)
+
+        xml_df = self.xml_parse(path)
+        group_df = xml_df.groupby('filename')
+        for k, values in group_df:
+            value = ''
+            for id, v in values.iterrows():
+                class_name = str(class_dict[v['class_name']])
+                ymax, xmax, ymin, xmin = str(v['ymax']), str(v['xmax']), str(v['ymin']), str(v['xmin'])
+                single = ','.join([xmin, ymin, xmax, ymax, class_name])
+                value ="{};{}".format(value, single)
+            single_obj_str = "{}{}".format(k, value)
+            label_content_list.append(single_obj_str)
+
+        # write total_label.txt
+        count = 0
+        with open(total_data_txt, "w") as label_txt:
+            for index, label_content in enumerate(label_content_list):
+                label_txt.write(label_content)
+                if index < len(label_content_list) - 1:
+                    label_txt.write("\n")
+                count += 1
+        print("%s 写入 %d 行 "%(total_data_txt, count))
+
+    def yolo5_normalized(self, df):
+        x = abs(df['xmax'] - df['xmin'])
+        y = abs(df['ymax'] - df['ymin'])
+        x_center = (x / 2 + df['xmin']) / df['width']
+        y_center = (y / 2 + df['ymin']) / df['height']
+        width = x / df['width']
+        height = y / df['height']
+
+        return x_center, y_center, width, height
+
+    def xml_to_yolo_txt(self, yolo_txt, path, class_index_json):
+        """
+        Convert the xml files to YOLO-format txt label files: one txt file per
+        image under <yolo_txt>/labels/, with one box per line written as
+
+        class_index,x_center,y_center,width,height
+
+        where all coordinates are normalized to [0, 1] by the image size.
+
+        @param yolo_txt: the folder the per-image label files are written to
+        @param path: the folder holding the images and xml annotations
+        @param class_index_json: a json file whose keys are class names and values are class indices
+        @return:
+        """
+        class_dict = fileUtil.ReadFile().read_json_label(class_index_json)
+
+        xml_df = self.xml_parse(path)
+        group_df = xml_df.groupby('filename')
+        count = 0
+        for k, values in group_df:
+            label_content_list = []
+            for id, v in values.iterrows():
+                class_name = str(class_dict[v['class_name']])
+                x_center, y_center, width, height = self.yolo5_normalized(v)
+                single = ','.join([class_name, str(x_center), str(y_center), str(width), str(height)])
+                label_content_list.extend([single])
+
+            file_ = '.'.join(str(k).split('.')[0:-1])
+            file_name = pathUtil.path_format(file_ + '.txt')
+            yolo_txt_path = pathUtil.path_format_join(yolo_txt, 'labels')
+            pathUtil.mkdir(yolo_txt_path)
+            txt_path = pathUtil.path_format_join(yolo_txt_path, file_name)
+
+            with open(txt_path, "w+") as label_txt:
+                for index, label_content in enumerate(label_content_list):
+                    label_txt.write(label_content)
+                    if index < len(label_content_list) - 1:
+                        label_txt.write("\n")
+            count += 1
+        print("%s 写入 %d 个txt文件 "%(yolo_txt, count))
+    @staticmethod
+    def write_xml(img, df_t, whd_list, total_data_dir):
+        """
+        Build an xml annotation document, fill in the data, and save it
+
+        :param img: path of the new image the feature patches were added to
+        :param df_t: coordinates of the feature patches in this image
+        :param whd_list: width/height/depth of the new image
+        :param total_data_dir: folder the final xml data is saved to
+        :return:
+        """
+        filename = img.split('/')[-1]
+
+        # 1. create the DOM document object
+        doc = minidom.Document()
+
+        # 2. create the root node and attach it to the document
+        root_node = doc.createElement("annotation")
+        doc.appendChild(root_node)
+
+        # 3. create child nodes (each wrapping a text node) and attach them to the root
+        folder_node = doc.createElement("folder")
+        folder_value = doc.createTextNode('ZS')
+        folder_node.appendChild(folder_value)
+        root_node.appendChild(folder_node)
+
+        filename_node = doc.createElement("filename")
+        filename_value = doc.createTextNode(filename)
+        filename_node.appendChild(filename_value)
+        root_node.appendChild(filename_node)
+
+        path_node = doc.createElement("path")
+        path_value = doc.createTextNode(img)
+        path_node.appendChild(path_value)
+        root_node.appendChild(path_node)
+
+        source_node = doc.createElement("source")
+        database_node = doc.createElement("database")
+        database_node.appendChild(doc.createTextNode("Unknown"))
+        source_node.appendChild(database_node)
+        root_node.appendChild(source_node)
+
+        size_node = doc.createElement("size")
+        for item, value in zip(["width", "height", "depth"], whd_list):
+            elem = doc.createElement(item)
+            elem.appendChild(doc.createTextNode(str(value)))
+            size_node.appendChild(elem)
+        root_node.appendChild(size_node)
+
+        seg_node = doc.createElement("segmented")
+        seg_node.appendChild(doc.createTextNode(str(0)))
+        root_node.appendChild(seg_node)
+
+        for _, df in df_t.iterrows():
+            obj_node = doc.createElement("object")
+            name_node = doc.createElement("name")
+            name_node.appendChild(doc.createTextNode(str(df['class'])))
+            obj_node.appendChild(name_node)
+
+            pose_node = doc.createElement("pose")
+            pose_node.appendChild(doc.createTextNode("Unspecified"))
+            obj_node.appendChild(pose_node)
+
+            trun_node = doc.createElement("truncated")
+            trun_node.appendChild(doc.createTextNode(str(0)))
+            obj_node.appendChild(trun_node)
+
+            trun_node = doc.createElement("difficult")
+            trun_node.appendChild(doc.createTextNode(str(0)))
+            obj_node.appendChild(trun_node)
+
+            bndbox_node = doc.createElement("bndbox")
+            for item, value in zip(["xmin", "ymin", "xmax", "ymax"], [df['xmin'], df['ymin'], df['xmax'], df['ymax']]):
+                elem = doc.createElement(item)
+                elem.appendChild(doc.createTextNode(str(value)))
+                bndbox_node.appendChild(elem)
+            obj_node.appendChild(bndbox_node)
+            root_node.appendChild(obj_node)
+
+        xml_file = filename.split('.')[0] + '.xml'
+        with open(pathUtil.path_format_join(total_data_dir, xml_file), "w", encoding="utf-8") as f:
+            # 4. writexml(): the first argument is the target file object, the second the
+            # root-node indent, the third the indent of child nodes, the fourth the
+            # newline style, and the fifth the encoding of the xml content.
+            doc.writexml(f, indent='', addindent='\t', newl='\n', encoding="utf-8")

+ 80 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/utils/select_data_util.py

@@ -0,0 +1,80 @@
+"""
+# File       : select_data_util.py
+# Time       :21.7.2 18:42
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""
+
+_BBOX_SELECT = (
+    'SELECT path,filename,width,height,depth,state_name,xmin,ymin,xmax,ymax '
+    'FROM img_bbox ibb '
+    'LEFT JOIN class_state cs ON ibb.state_id = cs.id '
+    'LEFT JOIN img_basic ib ON ibb.file_id = ib.id '
+)
+
+
+def _fetch_all(mydb, sql, params):
+    # Run a parameterized query (avoids SQL injection and quoting problems)
+    # and return all rows.
+    mycursor = mydb.cursor()
+    mycursor.execute(sql, params)
+    result = mycursor.fetchall()
+    mycursor.close()
+    return result
+
+
+def select_class_name_and_spot_id(mydb, class_name, spot_id):
+    sql = _BBOX_SELECT + 'WHERE cs.class_name=%s AND ibb.spot_id=%s'
+    return _fetch_all(mydb, sql, (class_name, spot_id))
+
+
+def select_state_name_and_spot_id(mydb, state_name, spot_id):
+    sql = _BBOX_SELECT + 'WHERE cs.state_name=%s AND ibb.spot_id=%s'
+    return _fetch_all(mydb, sql, (state_name, spot_id))
+
+
+def select_class_name(mydb, class_name):
+    sql = _BBOX_SELECT + 'WHERE cs.class_name=%s'
+    return _fetch_all(mydb, sql, (class_name,))
+
+
+def select_state_name(mydb, state_name):
+    sql = _BBOX_SELECT + 'WHERE cs.state_name=%s'
+    return _fetch_all(mydb, sql, (state_name,))
+
+
+def select_state_name_count_num(mydb, state_name):
+    sql = ('SELECT count(*) FROM img_bbox ibb '
+           'LEFT JOIN class_state cs ON ibb.state_id = cs.id '
+           'WHERE cs.state_name=%s')
+    return _fetch_all(mydb, sql, (state_name,))[0]
+
+
+def select_spot_id(mydb, spot_id):
+    sql = _BBOX_SELECT + 'WHERE ibb.spot_id=%s'
+    return _fetch_all(mydb, sql, (spot_id,))
+
+
+def select_class_state_class_name(mydb, class_name):
+    sql = 'SELECT state_name FROM class_state WHERE class_name = %s'
+    return _fetch_all(mydb, sql, (class_name,))
+
+
+def select_spot_name_get_spot_id(mydb, spot):
+    # The spot name is a string; the parameterized query quotes it correctly.
+    sql = 'SELECT spot_id FROM spot_basic WHERE spot = %s'
+    return _fetch_all(mydb, sql, (spot,))[0][0]
+
+# if __name__ == "__main__":
+#     mydb = connect_db_util.connect_db().mydb()
+#     result = select_Data().select_img_bbox_max(mydb)
+#     print(result)

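A minimal usage sketch for these query helpers (host and credentials are placeholders; the import path mirrors the one used in run_linux.py below):

    import mysql.connector
    from gen_data.utils import select_data_util

    mydb = mysql.connector.connect(host='127.0.0.1', user='root', passwd='root',
                                   database='db_img', buffered=True)
    rows = select_data_util.select_class_name(mydb, 'E1')
    for path, filename, width, height, depth, state_name, xmin, ymin, xmax, ymax in rows:
        print(filename, state_name, (xmin, ymin, xmax, ymax))
    mydb.close()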
+ 39 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/utils/strUtil.py

@@ -0,0 +1,39 @@
+"""
+# File       : strUtil.py
+# Time       :21.5.25 18:50
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:String helper utilities
+"""
+import sys
+
+def profile2str(st):
+    """Strip double and single quotes from a config string value."""
+    return st.replace('"', '').replace('\'', '')
+
+
+def is_num(n, key):
+    """Return n as an int; exit with a clear message naming the offending field."""
+    try:
+        num = int(n)
+        return num
+    except (TypeError, ValueError):
+        sys.exit('Please make sure the value of the %s field is a number' % key)
+
+
+def is_absnum(n, key):
+    """Return the absolute value of n after validating it with is_num()."""
+    num = is_num(n, key)
+    return abs(num)
+
+
+def num_str(st):
+    """Return st as an int when it parses as one, otherwise return it unchanged."""
+    try:
+        num = int(st)
+        return num
+    except (TypeError, ValueError):
+        return st
+
+

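A quick usage sketch for these string helpers (values are illustrative; the import path mirrors the sibling utils modules):

    from gen_data.utils import strUtil

    strUtil.profile2str('"TotalData"')   # -> 'TotalData'
    strUtil.is_num('20', 'frame_num')    # -> 20; exits with a message when not numeric
    strUtil.is_absnum('-3', 'percent')   # -> 3
    strUtil.num_str('7')                 # -> 7
    strUtil.num_str('abc')               # -> 'abc'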
+ 150 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/gen_data/utils/write_xml_jpg_2_all_data_util.py

@@ -0,0 +1,150 @@
+"""
+# File       : write_xml_jpg_2_all_data_util.py
+# Time       :21.7.30 11:28
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""
+import math
+import time
+from gen_data.utils import pathUtil
+from xml.dom import minidom
+import shutil
+import tqdm
+import threading
+
+
+def copy_jpg_thread(all_df, path):
+    """Copy all referenced images into <path>/TotalData using five worker threads."""
+    all_df = all_df.drop_duplicates()
+    # Drive-letter prefix for Windows paths such as "D:/..."; empty otherwise.
+    # (Currently unused by copy_jpg, kept for compatibility with older call sites.)
+    if ":/" in path:
+        mapping = path.split(":/")[0] + ":/"
+    else:
+        mapping = ''
+    total_data_path = pathUtil.path_format_join(path, 'TotalData')
+    pathUtil.mkdir_new(total_data_path)
+    threads = []
+    df_list = splitdf(all_df, 5)
+    for df in df_list:
+        t = threading.Thread(target=copy_jpg, args=(mapping, df, total_data_path))
+        t.daemon = True
+        threads.append(t)
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join()
+
+def copy_jpg(mapping, all_df, total_data_path):
+    """Copy each image into total_data_path and write its VOC-style XML next to it."""
+    groups = all_df.groupby('path')
+    for key, values in tqdm.tqdm(groups, total=len(groups), ncols=60):
+        imgshape = values.head(1)[['width', 'height', 'depth']].values[0]
+        old_path = key
+        image_name = pathUtil.path_format(key).split('/')[-1]
+        new_path = pathUtil.path_format_join(total_data_path, image_name)
+        shutil.copyfile(old_path, new_path)
+        write_xml(key, values, imgshape, total_data_path)
+
+
+def write_xml(img, df_t, whd_list, total_data_dir):
+    """
+    生成xml文件,写入数据后保存
+
+    :param img:添加过特征图像后新图片的路径
+    :param df_t:该图片中特征图像的坐标信息
+    :param whd_list:新图片的长宽信息
+    :param total_data_dir:最终保存xml数据的路径
+    :return:
+    """
+    filename = img.split('/')[-1]
+
+    # 1. Create the DOM document object
+    doc = minidom.Document()
+
+    # 2. Create the root node and attach it to the document
+    root_node = doc.createElement("annotation")
+    doc.appendChild(root_node)
+
+    # 3. Create child nodes, each holding a text node, and append them to the root
+    folder_node = doc.createElement("folder")
+    folder_value = doc.createTextNode('ZS')
+    folder_node.appendChild(folder_value)
+    root_node.appendChild(folder_node)
+
+    filename_node = doc.createElement("filename")
+    filename_value = doc.createTextNode(filename)
+    filename_node.appendChild(filename_value)
+    root_node.appendChild(filename_node)
+
+    path_node = doc.createElement("path")
+    path_value = doc.createTextNode(img)
+    path_node.appendChild(path_value)
+    root_node.appendChild(path_node)
+
+    source_node = doc.createElement("source")
+    database_node = doc.createElement("database")
+    database_node.appendChild(doc.createTextNode("Unknown"))
+    source_node.appendChild(database_node)
+    root_node.appendChild(source_node)
+
+    size_node = doc.createElement("size")
+    for item, value in zip(["width", "height", "depth"], whd_list):
+        elem = doc.createElement(item)
+        elem.appendChild(doc.createTextNode(str(value)))
+        size_node.appendChild(elem)
+    root_node.appendChild(size_node)
+
+    seg_node = doc.createElement("segmented")
+    seg_node.appendChild(doc.createTextNode(str(0)))
+    root_node.appendChild(seg_node)
+
+    for _, df in df_t.iterrows():
+        obj_node = doc.createElement("object")
+        name_node = doc.createElement("name")
+        name_node.appendChild(doc.createTextNode(str(df['class_name'])))
+        obj_node.appendChild(name_node)
+
+        pose_node = doc.createElement("pose")
+        pose_node.appendChild(doc.createTextNode("Unspecified"))
+        obj_node.appendChild(pose_node)
+
+        trun_node = doc.createElement("truncated")
+        trun_node.appendChild(doc.createTextNode(str(0)))
+        obj_node.appendChild(trun_node)
+
+        trun_node = doc.createElement("difficult")
+        trun_node.appendChild(doc.createTextNode(str(0)))
+        obj_node.appendChild(trun_node)
+
+        bndbox_node = doc.createElement("bndbox")
+        for item, value in zip(["xmin", "ymin", "xmax", "ymax"], [df['xmin'], df['ymin'], df['xmax'], df['ymax']]):
+            elem = doc.createElement(item)
+            elem.appendChild(doc.createTextNode(str(value)))
+            bndbox_node.appendChild(elem)
+        obj_node.appendChild(bndbox_node)
+        root_node.appendChild(obj_node)
+
+    xml_file = filename.split('.jpg')[0] + '.xml'
+    with open(pathUtil.path_format_join(total_data_dir, xml_file), "w", encoding="utf-8") as f:
+        # 4. writexml(): the first argument is the target file object, the second the
+        #    indent of the root node, the third the indent added for child nodes, the
+        #    fourth the newline style, and the fifth the encoding of the XML content.
+        doc.writexml(f, indent='', addindent='\t', newl='\n', encoding="utf-8")
+
+def splitdf(df, num):
+    """Split df into num roughly equal chunks; the last chunk absorbs the remainder."""
+    linenum = math.floor(len(df) / num)
+    pdlist = []
+    for i in range(num - 1):
+        pdlist.append(df[i * linenum:(i + 1) * linenum])
+    # The final chunk runs to the end of df, so no row is dropped or duplicated.
+    pdlist.append(df[(num - 1) * linenum:len(df)])
+    return pdlist

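A small sanity check for splitdf (assuming pandas): with 11 rows and num=5, linenum is 2, so the first four chunks hold rows 0-1, 2-3, 4-5 and 6-7, and the last chunk holds rows 8-10; every row lands in exactly one chunk.

    import pandas as pd
    df = pd.DataFrame({'path': range(11)})
    chunks = splitdf(df, 5)
    print([len(c) for c in chunks])                # [2, 2, 2, 2, 3]
    assert sum(len(c) for c in chunks) == len(df)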
+ 251 - 0
code/data_manage/test_util/Qt5/gen_tfrecord_ui/run_linux.py

@@ -0,0 +1,251 @@
+# -*- coding: utf-8 -*-
+
+# Form implementation generated from reading ui file 'frame.ui'
+#
+# Created by: PyQt5 UI code generator 5.15.4
+#
+# WARNING: Any manual changes made to this file will be lost when pyuic5 is
+# run again.  Do not edit this file unless you know what you are doing.
+import os
+import sys
+import pandas as pd
+import mysql.connector
+from gen_data.utils import select_data_util, write_xml_jpg_2_all_data_util, pathUtil
+from gen_data import check_data_step2, classFeatureImages_step3, gen_class_index_step4, genAnn_step5, \
+    splitTrainVal_step6, gen_pb_tfrecord_step7
+
+
+class run():
+    def __init__(self, host, user, passwd, database):
+        self.host = host
+        self.user = user
+        self.passwd = passwd
+        self.database = database
+
+    @staticmethod
+    def lineEdit_9_str2list(lineEdit):
+        # Parse "class_name,spot_id,percent;..." into [[name, spot_id, percent], ...].
+        # Returns None for empty or malformed input.
+        lists = []
+        if lineEdit.strip() == '':
+            return None
+        for lines in lineEdit.split(';'):
+            line_list = lines.split(',')
+            if len(line_list) != 3:
+                return None
+            spot_id = int(line_list[1])
+            percent = float(line_list[2])
+            # "all" together with spot_id 0 would select the entire database; reject it.
+            if line_list[0].lower() == 'all' and spot_id == 0:
+                return None
+            if spot_id >= 0 and 0 < percent <= 1:
+                lists.append([line_list[0].upper(), spot_id, percent])
+            else:
+                return None
+
+        return lists
+
+    @staticmethod
+    def lineEdit_18_str2list(lineEdit):
+        # Split a comma-separated list of names; returns None for empty input.
+        if lineEdit.strip() == '':
+            return None
+        return [item.strip() for item in lineEdit.split(',')]
+
+    @staticmethod
+    def my_db_connect(host, user, passwd, database):
+        """
+        生成mysql连接
+        :return:
+        """
+        my_db = mysql.connector.connect(host=host, user=user, passwd=passwd, database=database, buffered=True)
+        return my_db
+
+    def push_button(self):
+        """
+        Database connection test.
+        :return:
+        :rtype:
+        """
+        try:
+            my_db = self.my_db_connect(self.host, self.user, self.passwd,
+                                       self.database)
+            if my_db:
+                print('Database connection succeeded')
+            else:
+                print('Database connection failed')
+        except mysql.connector.Error:
+            print('Database connection failed')
+
+
+
+    def push_button_1(self, filter, save_path):
+        """
+        Tool 1: pull the filtered data and run steps 2-7 on it.
+        :return:
+        :rtype:
+        """
+        filter_list = None
+        if filter and save_path:
+            try:
+                filter_list = self.lineEdit_9_str2list(filter)
+            except (ValueError, IndexError):
+                print('The filter conditions are malformed, please check!')
+            if filter_list:
+                all_df = pd.DataFrame()
+                my_db = self.my_db_connect(self.host, self.user, self.passwd, self.database)
+                all_df = self.list_items_select_data_2_df(my_db, filter_list, all_df)
+
+                write_xml_jpg_2_all_data_util.copy_jpg_thread(all_df, pathUtil.path_format(save_path))
+                self.step2_to_step7(save_path)
+            # example filter entry: k16,1,0.3
+            else:
+                print('The filter conditions are invalid, please check')
+
+        else:
+            print('The filter conditions and save path must not be empty')
+
+    # def push_button_2(self):
+    #     if self.lineEdit_19.text() and self.lineEdit_18.text() and self.lineEdit_24.text():
+    #
+    #         try:
+    #             filter_list = self.lineEdit_18_str2list(self.lineEdit_18.text())
+    #             if filter_list:
+    #                 try:
+    #                     percent = float(self.lineEdit_19.text())
+    #                     if 0 < percent <= 1:
+    #
+    #                         print(filter_list)
+    #                     else:
+    #                         self.show_message_num()
+    #                 except:
+    #                     self.show_message_num()
+    #             else:
+    #                 self.show_filter_message_common_error()
+    #         except:
+    #             self.show_filter_message_common_error()
+    #     else:
+    #         self.show_message_isempty()
+
+    # def push_button_2(self):
+    #     '''
+    #     Tool 2
+    #     :return:
+    #     :rtype:
+    #     '''
+    #     global filter_list
+    #     if self.lineEdit_18.text():
+    #         my_db = self.my_db_connect(self.host, self.user, self.passwd,
+    #                                self.database)
+    #         try:
+    #             result_str = self.select_class_name_count_num(my_db, self.lineEdit_18.text())
+    #             self.show_class_name_count_num(result_str)
+    #         except:
+    #             self.show_database_error()
+    #     else:
+    #         self.show_message_isempty()
+    #
+    # def push_button_3(self):
+    #     """
+    #     Tool 3
+    #     :return:
+    #     :rtype:
+    #     """
+    #     if self.lineEdit_25.text():
+    #         spot_name = self.lineEdit_25.text().split('/')[-1].split('\\')[-1]
+    #         my_db = self.my_db_connect(self.host, self.user, self.passwd,
+    #                                self.database)
+    #         try:
+    #             spot_id = select_data_util.select_spot_name_get_spot_id(my_db, spot_name)
+    #             self.show_spot_id(spot_name, spot_id)
+    #         except:
+    #             self.show_database_error()
+    #         my_db.close()
+    #     else:
+    #         self.show_message_isempty()
+
+    def list_items_select_data_2_df(self, my_db, lists, all_df):
+        for li in lists:
+            if li[0].upper() == 'ALL':
+                df = self.select_spot_id_2_df(my_db, li[1], li[2])
+                all_df = all_df.append(df)
+            elif li[1] == 0:
+                df = self.select_class_name_2_df(my_db, li[0], li[2])
+                all_df = all_df.append(df)
+            else:
+                df = self.select_class_name_and_spot_id_2_df(my_db, li[0], li[1], li[2])
+                all_df = all_df.append(df)
+        return all_df
+
+    @staticmethod
+    def select_class_name_and_spot_id_2_df(mydb, class_name, spot_id, percent):
+
+        if '-' in class_name:
+            state_name_list = select_data_util.select_state_name_and_spot_id(mydb, class_name, spot_id)
+        else:
+            state_name_list = select_data_util.select_class_name_and_spot_id(mydb, class_name, spot_id)
+        total_df = pd.DataFrame(state_name_list,
+                                columns=['path', 'filename', 'width', 'height', 'depth', 'class_name', 'xmin', 'ymin',
+                                         'xmax', 'ymax'])
+        total_df = total_df.sample(frac=percent)
+        return total_df
+
+    @staticmethod
+    def select_class_name_2_df(mydb, class_name, percent):
+
+        if '-' in class_name:
+            state_name_list = select_data_util.select_state_name(mydb, class_name)
+        else:
+            state_name_list = select_data_util.select_class_name(mydb, class_name)
+        total_df = pd.DataFrame(state_name_list,
+                                columns=['path', 'filename', 'width', 'height', 'depth', 'class_name', 'xmin', 'ymin',
+                                         'xmax', 'ymax'])
+        total_df = total_df.sample(frac=percent)
+        return total_df
+
+    @staticmethod
+    def select_class_name_count_num(mydb, class_name):
+        results_str = ''
+        if '-' in class_name:
+            result = select_data_util.select_state_name_count_num(mydb, class_name.upper())
+            results_str = results_str + ',' + '%s:%d' % (class_name, result[0])
+        else:
+            state_name_list = select_data_util.select_class_state_class_name(mydb, class_name.upper())
+            print(state_name_list)
+            for state_name in state_name_list:
+                result = select_data_util.select_state_name_count_num(mydb, state_name[0].upper())
+                results_str = results_str + ',' + '%s:%d' % (state_name[0], result[0])
+        results_str = results_str[1:]
+        return results_str
+
+    @staticmethod
+    def select_spot_id_2_df(mydb, spot_id, percent):
+
+        state_name_list = select_data_util.select_spot_id(mydb, spot_id)
+        total_df = pd.DataFrame(state_name_list,
+                                columns=['path', 'filename', 'width', 'height', 'depth', 'class_name', 'xmin', 'ymin',
+                                         'xmax', 'ymax'])
+        total_df = total_df.sample(frac=percent)
+        return total_df
+
+    def step2_to_step7(self, path: str):
+        print(check_data_step2.__file__)
+        flag, mess = check_data_step2.checkFile(path).main()
+        if flag == 0:
+            print(mess)
+        else:
+            classFeatureImages_step3.classFeatureImages(path).main()
+            gen_class_index_step4.genClassIndex(path).main()
+            genAnn_step5.ganAnn(path).main()
+            splitTrainVal_step6.splitTrainVal(path).main()
+            gen_pb_tfrecord_step7.genPb(path).main()
+            gen_pb_tfrecord_step7.genTfrecord(path).main()
+            print('Data split finished!')
+
+
+
+if __name__ == "__main__":
+    run = run('192.168.20.249', 'root', 'root', 'db_img')
+    filter = input('Enter the filter conditions: ')
+    save_path = input('Enter the save path: ')
+    print('Testing')
+    run.push_button_1(filter, save_path)
+

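The filter string parsed by lineEdit_9_str2list uses "class_name,spot_id,percent" entries separated by semicolons. A short sketch of how it behaves (lineEdit_9_str2list is a @staticmethod, so it can be called on the run class directly; values are illustrative):

    run.lineEdit_9_str2list('E1,0,0.2;K10,12,0.3')  # -> [['E1', 0, 0.2], ['K10', 12, 0.3]]
    run.lineEdit_9_str2list('all,0,1')              # -> None ('all' with spot_id 0 is rejected)
    run.lineEdit_9_str2list('E1,0')                 # -> None (each entry needs three fields)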
+ 515 - 0
code/data_manage/test_util/Qt5/mysql_find.py

@@ -0,0 +1,515 @@
+"""
+# File       : mysql_find.py
+# Time       :21.10.12 17:32
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""
+# -*- coding: utf-8 -*-
+
+# Form implementation generated from reading ui file 'frame.ui'
+#
+# Created by: PyQt5 UI code generator 5.15.4
+#
+# WARNING: Any manual changes made to this file will be lost when pyuic5 is
+# run again.  Do not edit this file unless you know what you are doing.
+import os
+import sys
+import pandas as pd
+import mysql.connector
+# Required by push_button_1/step2_to_step7 below; the import paths mirror run_linux.py.
+from gen_data.utils import select_data_util, write_xml_jpg_2_all_data_util, pathUtil
+from gen_data import check_data_step2, classFeatureImages_step3, gen_class_index_step4, genAnn_step5, \
+    splitTrainVal_step6, gen_pb_tfrecord_step7
+from PyQt5 import QtCore, QtWidgets
+from PyQt5.QtWidgets import QFileDialog, QMessageBox
+
+
+class Ui_Form(QtWidgets.QMainWindow):
+    def __init__(self):
+        super(Ui_Form, self).__init__()
+        self.setupUi(self)
+        self.retranslateUi(self)
+        self.cwd = os.getcwd()
+
+    def setupUi(self, Form):
+        Form.setObjectName("Form")
+        Form.resize(648, 783)
+        self.widget = QtWidgets.QWidget(Form)
+        self.widget.setGeometry(QtCore.QRect(60, 35, 520, 616))
+        self.widget.setObjectName("widget")
+        self.verticalLayout_5 = QtWidgets.QVBoxLayout(self.widget)
+        self.verticalLayout_5.setContentsMargins(0, 0, 0, 0)
+        self.verticalLayout_5.setObjectName("verticalLayout_5")
+        self.verticalLayout = QtWidgets.QVBoxLayout()
+        self.verticalLayout.setObjectName("verticalLayout")
+        self.label_15 = QtWidgets.QLabel(self.widget)
+        self.label_15.setObjectName("label_15")
+        self.verticalLayout.addWidget(self.label_15, 0, QtCore.Qt.AlignHCenter)
+        self.horizontalLayout_3 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_3.setObjectName("horizontalLayout_3")
+        self.label_3 = QtWidgets.QLabel(self.widget)
+        self.label_3.setObjectName("label_3")
+        self.horizontalLayout_3.addWidget(self.label_3)
+        self.lineEdit_3 = QtWidgets.QLineEdit('127.0.0.1', self.widget)
+        self.lineEdit_3.setObjectName("lineEdit_3")
+        self.horizontalLayout_3.addWidget(self.lineEdit_3)
+        self.label_4 = QtWidgets.QLabel(self.widget)
+        self.label_4.setObjectName("label_4")
+        self.horizontalLayout_3.addWidget(self.label_4)
+        self.lineEdit_4 = QtWidgets.QLineEdit('3306', self.widget)
+        self.lineEdit_4.setObjectName("lineEdit_4")
+        self.horizontalLayout_3.addWidget(self.lineEdit_4)
+        self.verticalLayout.addLayout(self.horizontalLayout_3)
+        self.horizontalLayout_2 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_2.setObjectName("horizontalLayout_2")
+        self.label_5 = QtWidgets.QLabel(self.widget)
+        self.label_5.setObjectName("label_5")
+        self.horizontalLayout_2.addWidget(self.label_5)
+        self.lineEdit_5 = QtWidgets.QLineEdit('root', self.widget)
+        self.lineEdit_5.setObjectName("lineEdit_5")
+        self.horizontalLayout_2.addWidget(self.lineEdit_5)
+        self.label_6 = QtWidgets.QLabel(self.widget)
+        self.label_6.setObjectName("label_6")
+        self.horizontalLayout_2.addWidget(self.label_6)
+        self.lineEdit_6 = QtWidgets.QLineEdit('root', self.widget)
+        self.lineEdit_6.setEchoMode(QtWidgets.QLineEdit.Password)
+        self.lineEdit_6.setObjectName("lineEdit_6")
+        self.horizontalLayout_2.addWidget(self.lineEdit_6)
+        self.verticalLayout.addLayout(self.horizontalLayout_2)
+        self.horizontalLayout = QtWidgets.QHBoxLayout()
+        self.horizontalLayout.setObjectName("horizontalLayout")
+        self.label_7 = QtWidgets.QLabel(self.widget)
+        self.label_7.setObjectName("label_7")
+        self.horizontalLayout.addWidget(self.label_7)
+        self.lineEdit_7 = QtWidgets.QLineEdit('db_img', self.widget)
+        self.lineEdit_7.setObjectName("lineEdit_7")
+        self.horizontalLayout.addWidget(self.lineEdit_7)
+        self.pushButton = QtWidgets.QPushButton(self.widget)
+        self.pushButton.setObjectName("pushButton")
+        self.horizontalLayout.addWidget(self.pushButton)
+        self.verticalLayout.addLayout(self.horizontalLayout)
+        self.label_8 = QtWidgets.QLabel(self.widget)
+        self.label_8.setText("")
+        self.label_8.setObjectName("label_8")
+        self.verticalLayout.addWidget(self.label_8)
+        self.label_16 = QtWidgets.QLabel(self.widget)
+        self.label_16.setObjectName("label_16")
+        self.verticalLayout.addWidget(self.label_16)
+        self.verticalLayout_5.addLayout(self.verticalLayout)
+        self.verticalLayout_2 = QtWidgets.QVBoxLayout()
+        self.verticalLayout_2.setObjectName("verticalLayout_2")
+        self.label = QtWidgets.QLabel(self.widget)
+        self.label.setObjectName("label")
+        self.verticalLayout_2.addWidget(self.label, 0, QtCore.Qt.AlignHCenter)
+        self.label_10 = QtWidgets.QLabel(self.widget)
+        self.label_10.setObjectName("label_10")
+        self.verticalLayout_2.addWidget(self.label_10)
+        self.label_12 = QtWidgets.QLabel(self.widget)
+        self.label_12.setObjectName("label_12")
+        self.verticalLayout_2.addWidget(self.label_12)
+        self.label_13 = QtWidgets.QLabel(self.widget)
+        self.label_13.setObjectName("label_13")
+        self.verticalLayout_2.addWidget(self.label_13)
+        self.label_14 = QtWidgets.QLabel(self.widget)
+        self.label_14.setObjectName("label_14")
+        self.verticalLayout_2.addWidget(self.label_14)
+        self.label_11 = QtWidgets.QLabel(self.widget)
+        self.label_11.setObjectName("label_11")
+        self.verticalLayout_2.addWidget(self.label_11)
+        self.horizontalLayout_4 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_4.setObjectName("horizontalLayout_4")
+        self.label_9 = QtWidgets.QLabel(self.widget)
+        self.label_9.setObjectName("label_9")
+        self.horizontalLayout_4.addWidget(self.label_9)
+        self.lineEdit_9 = QtWidgets.QLineEdit(self.widget)
+        self.lineEdit_9.setObjectName("lineEdit_9")
+        self.horizontalLayout_4.addWidget(self.lineEdit_9)
+        self.verticalLayout_2.addLayout(self.horizontalLayout_4)
+        self.label_27 = QtWidgets.QLabel(self.widget)
+        self.label_27.setObjectName("label_27")
+        self.verticalLayout_2.addWidget(self.label_27)
+        self.label_28 = QtWidgets.QLabel(self.widget)
+        self.label_28.setObjectName("label_28")
+        self.verticalLayout_2.addWidget(self.label_28)
+        self.horizontalLayout_7 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_7.setObjectName("horizontalLayout_7")
+        self.label_21 = QtWidgets.QLabel(self.widget)
+        self.label_21.setObjectName("label_21")
+        self.horizontalLayout_7.addWidget(self.label_21)
+        self.pushButton_21 = QtWidgets.QPushButton(self.widget)
+        self.pushButton_21.setObjectName("pushButton_21")
+        self.horizontalLayout_7.addWidget(self.pushButton_21)
+        self.lineEdit_21 = QtWidgets.QLineEdit(self.widget)
+        self.lineEdit_21.setObjectName("lineEdit_21")
+        self.horizontalLayout_7.addWidget(self.lineEdit_21)
+        self.verticalLayout_2.addLayout(self.horizontalLayout_7)
+        self.pushButton_1 = QtWidgets.QPushButton(self.widget)
+        self.pushButton_1.setObjectName("pushButton_1")
+        self.verticalLayout_2.addWidget(self.pushButton_1)
+        self.label_26 = QtWidgets.QLabel(self.widget)
+        self.label_26.setText("")
+        self.label_26.setObjectName("label_26")
+        self.verticalLayout_2.addWidget(self.label_26)
+        self.label_17 = QtWidgets.QLabel(self.widget)
+        self.label_17.setObjectName("label_17")
+        self.verticalLayout_2.addWidget(self.label_17)
+        self.verticalLayout_5.addLayout(self.verticalLayout_2)
+        self.verticalLayout_3 = QtWidgets.QVBoxLayout()
+        self.verticalLayout_3.setObjectName("verticalLayout_3")
+        self.label_23 = QtWidgets.QLabel(self.widget)
+        self.label_23.setObjectName("label_23")
+        self.verticalLayout_3.addWidget(self.label_23, 0, QtCore.Qt.AlignHCenter)
+        self.horizontalLayout_5 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_5.setObjectName("horizontalLayout_5")
+        self.label_18 = QtWidgets.QLabel(self.widget)
+        self.label_18.setObjectName("label_18")
+        self.horizontalLayout_5.addWidget(self.label_18)
+        self.lineEdit_18 = QtWidgets.QLineEdit(self.widget)
+        self.lineEdit_18.setObjectName("lineEdit_18")
+        self.horizontalLayout_5.addWidget(self.lineEdit_18)
+        self.verticalLayout_3.addLayout(self.horizontalLayout_5)
+        self.horizontalLayout_8 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_8.setObjectName("horizontalLayout_8")
+        self.verticalLayout_3.addLayout(self.horizontalLayout_8)
+        self.pushButton_2 = QtWidgets.QPushButton(self.widget)
+        self.pushButton_2.setObjectName("pushButton_2")
+        self.verticalLayout_3.addWidget(self.pushButton_2)
+        self.label_2 = QtWidgets.QLabel(self.widget)
+        self.label_2.setText("")
+        self.label_2.setObjectName("label_2")
+        self.verticalLayout_3.addWidget(self.label_2)
+        self.label_22 = QtWidgets.QLabel(self.widget)
+        self.label_22.setObjectName("label_22")
+        self.verticalLayout_3.addWidget(self.label_22)
+        self.verticalLayout_5.addLayout(self.verticalLayout_3)
+        self.verticalLayout_4 = QtWidgets.QVBoxLayout()
+        self.verticalLayout_4.setObjectName("verticalLayout_4")
+        self.label_20 = QtWidgets.QLabel(self.widget)
+        self.label_20.setEnabled(True)
+        self.label_20.setObjectName("label_20")
+        self.verticalLayout_4.addWidget(self.label_20, 0, QtCore.Qt.AlignHCenter)
+        self.horizontalLayout_6 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_6.setObjectName("horizontalLayout_6")
+        self.label_25 = QtWidgets.QLabel(self.widget)
+        self.label_25.setObjectName("label_25")
+        self.horizontalLayout_6.addWidget(self.label_25)
+        self.pushButton_25 = QtWidgets.QPushButton(self.widget)
+        self.pushButton_25.setObjectName("pushButton_25")
+        self.horizontalLayout_6.addWidget(self.pushButton_25)
+        self.lineEdit_25 = QtWidgets.QLineEdit(self.widget)
+        self.lineEdit_25.setObjectName("lineEdit_25")
+        self.horizontalLayout_6.addWidget(self.lineEdit_25)
+        self.verticalLayout_4.addLayout(self.horizontalLayout_6)
+        self.pushButton_3 = QtWidgets.QPushButton(self.widget)
+        self.pushButton_3.setObjectName("pushButton_3")
+        self.verticalLayout_4.addWidget(self.pushButton_3)
+        self.verticalLayout_5.addLayout(self.verticalLayout_4)
+
+        self.pushButton_21.clicked.connect(self.openfolder_images_path_lineEdit_21)
+        # self.pushButton_24.clicked.connect(self.openfolder_images_path_lineEdit_24)
+        self.pushButton_25.clicked.connect(self.openfolder_images_path_lineEdit_25)
+        self.pushButton_1.clicked.connect(self.push_button_1)
+        self.pushButton_2.clicked.connect(self.push_button_2)
+        self.pushButton_3.clicked.connect(self.push_button_3)
+
+        self.retranslateUi(Form)
+        QtCore.QMetaObject.connectSlotsByName(Form)
+
+    def retranslateUi(self, Form):
+        _translate = QtCore.QCoreApplication.translate
+        Form.setWindowTitle(_translate("Form", "Form"))
+        self.label_15.setText(_translate("Form", "数据库配置信息"))
+        self.label_3.setText(_translate("Form", "Host:  "))
+        self.label_4.setText(_translate("Form", "Port:     "))
+        self.label_5.setText(_translate("Form", "User:  "))
+        self.label_6.setText(_translate("Form", "Password: "))
+        self.label_7.setText(_translate("Form", "dbName:"))
+        self.pushButton.setText(_translate("Form", "数据库连接测试"))
+        self.label_16.setText(_translate("Form",
+                                         "--------------------------------------------------------------------------------------"))
+        self.label.setText(_translate("Form", "工具1:根据类别和项目名称提取特征数据生成tfrecord"))
+        self.label_10.setText(_translate("Form", "数据表名:"))
+        self.label_12.setText(_translate("Form", "    class_name 类别名称(不是state_name)要获取全部,填写all"))
+        self.label_13.setText(_translate("Form", "    spot_id 地点ID,要获取全部,填写0(此时的class_name不能为all)"))
+        self.label_14.setText(_translate("Form", "    percent 获取数据的百分比(0-1],要获取全部,填写1"))
+        self.label_11.setText(_translate("Form", "    示例:E1,0,0.2;K10,12,0.3;all,1,1"))
+        self.label_9.setText(_translate("Form", "数据表名称:"))
+        self.label_18.setText(_translate("Form", "ID名称:"))
+        self.pushButton_3.setText(_translate("Form", "提交"))
+
+    def show_message_folder(self):
+        """
+        Message box.
+        :return:
+        """
+        QMessageBox.about(self, "Empty value", "No matching folder path was obtained!")
+
+    def show_message_isempty(self):
+        """
+        Message box.
+        :return:
+        """
+        QMessageBox.about(self, "Empty text box", "Text boxes must not be empty, please check!")
+
+    def show_filter_message_list_len_error(self):
+        """
+        Message box.
+        :return:
+        """
+        QMessageBox.about(self, "Error", "The filter-condition text box is invalid, please check!")
+
+    def show_filter_message_common_error(self):
+        """
+        Message box.
+        :return:
+        """
+        QMessageBox.about(self, "Error", "The non-generic field text box is invalid, please check!")
+
+    def show_message_num(self):
+        """
+        Message box.
+        :return:
+        """
+        QMessageBox.about(self, "Number error", "The input is not a number, or the number is less than 0!")
+
+    def openfolder_images_path_lineEdit_21(self):
+        """
+        Choose the folder the generated data is saved to.
+        :return:
+        """
+        openfolder_path = QFileDialog.getExistingDirectory(self, 'Choose a data save folder', self.cwd)
+        self.lineEdit_21.setText(openfolder_path)
+
+    def openfolder_images_path_lineEdit_24(self):
+        """
+        Choose the folder the generated data is saved to.
+        :return:
+        """
+        openfolder_path = QFileDialog.getExistingDirectory(self, 'Choose a data save folder', self.cwd)
+        self.lineEdit_24.setText(openfolder_path)
+
+    def openfolder_images_path_lineEdit_25(self):
+        """
+        Choose a folder.
+        :return:
+        """
+        openfolder_path = QFileDialog.getExistingDirectory(self, 'Open folder', self.cwd)
+        self.lineEdit_25.setText(openfolder_path)
+
+    def show_spot_id(self, spot_name, spot_id):
+        """
+        Message box.
+        :return:
+        """
+        QMessageBox.about(self, "Query result", "The ID for %s is: %s" % (spot_name, str(spot_id)))
+
+    def show_class_name_count_num(self, result):
+        """
+        Message box.
+        :return:
+        """
+        QMessageBox.about(self, "Query result", "Row counts for the class: %s" % (result))
+
+    def show_database_error(self):
+        """
+        Message box.
+        :return:
+        """
+        QMessageBox.about(self, "Database error", "Database query failed! Please check the query conditions")
+
+    def show_message_succes(self):
+        """
+        Message box.
+        :return:
+        """
+        QMessageBox.about(self, "Done", "Data generation finished!")
+
+    @staticmethod
+    def lineEdit_9_str2list(lineEdit):
+        # Parse "class_name,spot_id,percent;..." into [[name, spot_id, percent], ...].
+        # Returns None for empty or malformed input.
+        lists = []
+        if lineEdit.strip() == '':
+            return None
+        for lines in lineEdit.split(';'):
+            line_list = lines.split(',')
+            if len(line_list) != 3:
+                return None
+            spot_id = int(line_list[1])
+            percent = float(line_list[2])
+            # "all" together with spot_id 0 would select the entire database; reject it.
+            if line_list[0].lower() == 'all' and spot_id == 0:
+                return None
+            if spot_id >= 0 and 0 < percent <= 1:
+                lists.append([line_list[0].upper(), spot_id, percent])
+            else:
+                return None
+
+        return lists
+
+    @staticmethod
+    def lineEdit_18_str2list(lineEdit):
+        # Split a comma-separated list of names; returns None for empty input.
+        if lineEdit.strip() == '':
+            return None
+        return [item.strip() for item in lineEdit.split(',')]
+
+    @staticmethod
+    def my_db_connect(host, user, passwd, database):
+        """
+        生成mysql连接
+        :return:
+        """
+        my_db = mysql.connector.connect(host=host, user=user, passwd=passwd, database=database, buffered=True)
+        return my_db
+
+    def push_button_1(self):
+        filter_list = None
+        if self.lineEdit_9.text() and self.lineEdit_21.text():
+            try:
+                filter_list = self.lineEdit_9_str2list(self.lineEdit_9.text())
+            except (ValueError, IndexError):
+                self.show_filter_message_list_len_error()
+            if filter_list:
+                all_df = pd.DataFrame()
+                my_db = self.my_db_connect(self.lineEdit_3.text(), self.lineEdit_5.text(), self.lineEdit_6.text(),
+                                           self.lineEdit_7.text())
+                all_df = self.list_items_select_data_2_df(my_db, filter_list, all_df)
+
+                write_xml_jpg_2_all_data_util.copy_jpg_thread(all_df, pathUtil.path_format(self.lineEdit_21.text()))
+                self.step2_to_step7(self.lineEdit_21.text())
+            # example filter entry: k16,1,0.3
+            else:
+                self.show_filter_message_list_len_error()
+
+        else:
+            self.show_message_isempty()
+
+    # def push_button_2(self):
+    #     if self.lineEdit_19.text() and self.lineEdit_18.text() and self.lineEdit_24.text():
+    #
+    #         try:
+    #             filter_list = self.lineEdit_18_str2list(self.lineEdit_18.text())
+    #             if filter_list:
+    #                 try:
+    #                     percent = float(self.lineEdit_19.text())
+    #                     if 0 < percent <= 1:
+    #
+    #                         print(filter_list)
+    #                     else:
+    #                         self.show_message_num()
+    #                 except:
+    #                     self.show_message_num()
+    #             else:
+    #                 self.show_filter_message_common_error()
+    #         except:
+    #             self.show_filter_message_common_error()
+    #     else:
+    #         self.show_message_isempty()
+
+    def push_button_2(self):
+        if self.lineEdit_18.text():
+            my_db = self.my_db_connect(self.lineEdit_3.text(), self.lineEdit_5.text(), self.lineEdit_6.text(),
+                                       self.lineEdit_7.text())
+            try:
+                result_str = self.select_class_name_count_num(my_db, self.lineEdit_18.text())
+                self.show_class_name_count_num(result_str)
+            except:
+                self.show_database_error()
+        else:
+            self.show_message_isempty()
+
+    def push_button_3(self):
+        if self.lineEdit_25.text():
+            spot_name = self.lineEdit_25.text().split('/')[-1].split('\\')[-1]
+            my_db = self.my_db_connect(self.lineEdit_3.text(), self.lineEdit_5.text(), self.lineEdit_6.text(),
+                                       self.lineEdit_7.text())
+            try:
+                spot_id = select_data_util.select_spot_name_get_spot_id(my_db, spot_name)
+                self.show_spot_id(spot_name, spot_id)
+            except:
+                self.show_database_error()
+            my_db.close()
+        else:
+            self.show_message_isempty()
+
+    def list_items_select_data_2_df(self, my_db, lists, all_df):
+        for li in lists:
+            if li[0].upper() == 'ALL':
+                df = self.select_spot_id_2_df(my_db, li[1], li[2])
+                all_df = all_df.append(df)
+            elif li[1] == 0:
+                df = self.select_class_name_2_df(my_db, li[0], li[2])
+                all_df = all_df.append(df)
+            else:
+                df = self.select_class_name_and_spot_id_2_df(my_db, li[0], li[1], li[2])
+                all_df = all_df.append(df)
+        return all_df
+
+    @staticmethod
+    def select_class_name_and_spot_id_2_df(mydb, class_name, spot_id, percent):
+
+        if '-' in class_name:
+            state_name_list = select_data_util.select_state_name_and_spot_id(mydb, class_name, spot_id)
+        else:
+            state_name_list = select_data_util.select_class_name_and_spot_id(mydb, class_name, spot_id)
+        total_df = pd.DataFrame(state_name_list,
+                                columns=['path', 'filename', 'width', 'height', 'depth', 'class_name', 'xmin', 'ymin',
+                                         'xmax', 'ymax'])
+        total_df = total_df.sample(frac=percent)
+        return total_df
+
+    @staticmethod
+    def select_class_name_2_df(mydb, class_name, percent):
+
+        if '-' in class_name:
+            state_name_list = select_data_util.select_state_name(mydb, class_name)
+        else:
+            state_name_list = select_data_util.select_class_name(mydb, class_name)
+        total_df = pd.DataFrame(state_name_list,
+                                columns=['path', 'filename', 'width', 'height', 'depth', 'class_name', 'xmin', 'ymin',
+                                         'xmax', 'ymax'])
+        total_df = total_df.sample(frac=percent)
+        return total_df
+
+    @staticmethod
+    def select_class_name_count_num(mydb, class_name):
+        results_str = ''
+        if '-' in class_name:
+            result = select_data_util.select_state_name_count_num(mydb, class_name.upper())
+            results_str = results_str + ',' + '%s:%d' % (class_name, result[0])
+        else:
+            state_name_list = select_data_util.select_class_state_class_name(mydb, class_name.upper())
+            print(state_name_list)
+            for state_name in state_name_list:
+                result = select_data_util.select_state_name_count_num(mydb, state_name[0].upper())
+                results_str = results_str + ',' + '%s:%d' % (state_name[0], result[0])
+        results_str = results_str[1:]
+        return results_str
+
+    @staticmethod
+    def select_spot_id_2_df(mydb, spot_id, percent):
+
+        state_name_list = select_data_util.select_spot_id(mydb, spot_id)
+        total_df = pd.DataFrame(state_name_list,
+                                columns=['path', 'filename', 'width', 'height', 'depth', 'class_name', 'xmin', 'ymin',
+                                         'xmax', 'ymax'])
+        total_df = total_df.sample(frac=percent)
+        return total_df
+
+    @staticmethod
+    def step2_to_step7(path):
+
+        check_data_step2.checkFile(path).main()
+        classFeatureImages_step3.classFeatureImages(path).main()
+        gen_class_index_step4.genClassIndex(path).main()
+        genAnn_step5.ganAnn(path).main()
+        splitTrainVal_step6.splitTrainVal(path).main()
+        gen_pb_tfrecord_step7.genPb(path).main()
+        gen_pb_tfrecord_step7.genTfrecord(path).main()
+
+
+if __name__ == "__main__":
+    app = QtWidgets.QApplication(sys.argv)
+    QMainWindow = QtWidgets.QMainWindow()
+    ui = Ui_Form()
+    ui.setupUi(QMainWindow)
+    QMainWindow.show()
+    sys.exit(app.exec_())

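For reference, select_class_name_count_num returns a comma-joined "state_name:count" summary. A sketch with placeholder credentials and illustrative counts (both helpers are @staticmethods, so they can be called on the class):

    my_db = Ui_Form.my_db_connect('127.0.0.1', 'root', 'root', 'db_img')
    print(Ui_Form.select_class_name_count_num(my_db, 'e1'))
    # e.g. 'E1-A:120,E1-B:45' when class E1 maps to states E1-A and E1-B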
+ 8 - 0
code/data_manage/test_util/Qt5/videos2imgs/__init__.py

@@ -0,0 +1,8 @@
+"""
+# File       : __init__.py.py
+# Time       :21.6.24 16:01
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""

+ 519 - 0
code/data_manage/test_util/Qt5/videos2imgs/form1.3.py

@@ -0,0 +1,519 @@
+# -*- coding: utf-8 -*-
+
+# Form implementation generated from reading ui file 'form1.3.ui'
+#
+# Created by: PyQt5 UI code generator 5.15.4
+#
+# WARNING: Any manual changes made to this file will be lost when pyuic5 is
+# run again.  Do not edit this file unless you know what you are doing.
+
+import sys
+import cv2
+import time
+import math
+import os
+import glob
+import numpy as np
+from tqdm import tqdm
+from PyQt5 import QtCore, QtWidgets
+from PyQt5.QtWidgets import QFileDialog, QMessageBox
+
+
+class Ui_Form(QtWidgets.QMainWindow):
+    def __init__(self):
+        super(Ui_Form, self).__init__()
+        self.setupUi(self)
+        self.retranslateUi(self)
+        self.cwd = os.getcwd()
+
+    def setupUi(self, Form):
+        Form.setObjectName("Form")
+        Form.resize(567, 502)
+        Form.setStyleSheet("critical{font-color:red}")
+        self.widget = QtWidgets.QWidget(Form)
+        self.widget.setGeometry(QtCore.QRect(80, 30, 389, 392))
+        self.widget.setObjectName("widget")
+        self.verticalLayout_3 = QtWidgets.QVBoxLayout(self.widget)
+        self.verticalLayout_3.setContentsMargins(0, 0, 0, 0)
+        self.verticalLayout_3.setObjectName("verticalLayout_3")
+        self.verticalLayout = QtWidgets.QVBoxLayout()
+        self.verticalLayout.setObjectName("verticalLayout")
+        self.horizontalLayout = QtWidgets.QHBoxLayout()
+        self.horizontalLayout.setObjectName("horizontalLayout")
+        self.label_9 = QtWidgets.QLabel(self.widget)
+        self.label_9.setStyleSheet("color:rgb(255, 0, 0)")
+        self.label_9.setObjectName("label_9")
+        self.horizontalLayout.addWidget(self.label_9)
+        self.label_1 = QtWidgets.QLabel(self.widget)
+        self.label_1.setObjectName("label_1")
+        self.horizontalLayout.addWidget(self.label_1)
+        self.pushButton = QtWidgets.QPushButton(self.widget)
+        self.pushButton.setObjectName("pushButton")
+        self.horizontalLayout.addWidget(self.pushButton)
+        self.lineEdit_3 = QtWidgets.QLineEdit(self.widget)
+        self.lineEdit_3.setMinimumSize(QtCore.QSize(200, 0))
+        self.lineEdit_3.setObjectName("lineEdit_3")
+        self.horizontalLayout.addWidget(self.lineEdit_3)
+        self.verticalLayout.addLayout(self.horizontalLayout)
+        self.horizontalLayout_3 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_3.setObjectName("horizontalLayout_3")
+        self.label_18 = QtWidgets.QLabel(self.widget)
+        self.label_18.setStyleSheet("color:rgb(255, 0, 0)")
+        self.label_18.setObjectName("label_18")
+        self.horizontalLayout_3.addWidget(self.label_18)
+        self.label_2 = QtWidgets.QLabel(self.widget)
+        self.label_2.setObjectName("label_2")
+        self.horizontalLayout_3.addWidget(self.label_2)
+        self.pushButton_2 = QtWidgets.QPushButton(self.widget)
+        self.pushButton_2.setObjectName("pushButton_2")
+        self.horizontalLayout_3.addWidget(self.pushButton_2)
+        self.lineEdit_4 = QtWidgets.QLineEdit(self.widget)
+        self.lineEdit_4.setMinimumSize(QtCore.QSize(200, 0))
+        self.lineEdit_4.setObjectName("lineEdit_4")
+        self.horizontalLayout_3.addWidget(self.lineEdit_4)
+        self.verticalLayout.addLayout(self.horizontalLayout_3)
+        self.horizontalLayout_4 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_4.setObjectName("horizontalLayout_4")
+        self.label_3 = QtWidgets.QLabel(self.widget)
+        self.label_3.setObjectName("label_3")
+        self.horizontalLayout_4.addWidget(self.label_3)
+        self.lineEdit = QtWidgets.QLineEdit('mp4,mov,avi,wmv,m4v,flv', self.widget)
+        self.lineEdit.setMinimumSize(QtCore.QSize(200, 0))
+        self.lineEdit.setObjectName("lineEdit")
+        self.horizontalLayout_4.addWidget(self.lineEdit)
+        self.verticalLayout.addLayout(self.horizontalLayout_4)
+        self.horizontalLayout_5 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_5.setObjectName("horizontalLayout_5")
+        self.label_20 = QtWidgets.QLabel(self.widget)
+        self.label_20.setStyleSheet("color:rgb(255, 0, 0)")
+        self.label_20.setObjectName("label_20")
+        self.horizontalLayout_5.addWidget(self.label_20)
+        self.label_4 = QtWidgets.QLabel(self.widget)
+        self.label_4.setObjectName("label_4")
+        self.horizontalLayout_5.addWidget(self.label_4)
+        self.lineEdit_2 = QtWidgets.QLineEdit(self.widget)
+        self.lineEdit_2.setMinimumSize(QtCore.QSize(200, 0))
+        self.lineEdit_2.setObjectName("lineEdit_2")
+        self.horizontalLayout_5.addWidget(self.lineEdit_2)
+        self.verticalLayout.addLayout(self.horizontalLayout_5)
+        self.horizontalLayout_6 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_6.setObjectName("horizontalLayout_6")
+        self.label_6 = QtWidgets.QLabel(self.widget)
+        self.label_6.setObjectName("label_6")
+        self.horizontalLayout_6.addWidget(self.label_6)
+        self.radioButton = QtWidgets.QRadioButton(self.widget)
+        self.radioButton.setObjectName("radioButton")
+        self.radioButton.setEnabled(False)
+        self.radioButton.setChecked(True)
+        self.horizontalLayout_6.addWidget(self.radioButton)
+        self.radioButton_2 = QtWidgets.QRadioButton(self.widget)
+        self.radioButton_2.setObjectName("radioButton_2")
+        self.radioButton_2.setEnabled(False)
+        self.horizontalLayout_6.addWidget(self.radioButton_2)
+        self.verticalLayout.addLayout(self.horizontalLayout_6)
+        self.horizontalLayout_2 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_2.setObjectName("horizontalLayout_2")
+        self.label_7 = QtWidgets.QLabel(self.widget)
+        self.label_7.setObjectName("label_7")
+        self.horizontalLayout_2.addWidget(self.label_7)
+        self.lineEdit_6 = QtWidgets.QLineEdit(self.widget)
+        self.lineEdit_6.setObjectName("lineEdit_6")
+        self.horizontalLayout_2.addWidget(self.lineEdit_6)
+        self.pushButton_4 = QtWidgets.QPushButton(self.widget)
+        self.pushButton_4.setObjectName("pushButton_4")
+        self.horizontalLayout_2.addWidget(self.pushButton_4)
+        self.verticalLayout.addLayout(self.horizontalLayout_2)
+        self.horizontalLayout_7 = QtWidgets.QHBoxLayout()
+        self.horizontalLayout_7.setObjectName("horizontalLayout_7")
+        self.label_8 = QtWidgets.QLabel(self.widget)
+        self.label_8.setStyleSheet("color:rgb(255, 0, 0)")
+        self.label_8.setObjectName("label_8")
+        self.horizontalLayout_7.addWidget(self.label_8)
+        self.label_5 = QtWidgets.QLabel(self.widget)
+        self.label_5.setObjectName("label_5")
+        self.horizontalLayout_7.addWidget(self.label_5)
+        self.lineEdit_5 = QtWidgets.QLineEdit('20', self.widget)
+        self.lineEdit_5.setObjectName("lineEdit_5")
+        self.horizontalLayout_7.addWidget(self.lineEdit_5)
+        self.verticalLayout.addLayout(self.horizontalLayout_7)
+        self.pushButton_3 = QtWidgets.QPushButton(self.widget)
+        self.pushButton_3.setObjectName("pushButton_3")
+        self.verticalLayout.addWidget(self.pushButton_3)
+        self.verticalLayout_3.addLayout(self.verticalLayout)
+        self.verticalLayout_2 = QtWidgets.QVBoxLayout()
+        self.verticalLayout_2.setObjectName("verticalLayout_2")
+        self.label_10 = QtWidgets.QLabel(self.widget)
+        self.label_10.setObjectName("label_10")
+        self.verticalLayout_2.addWidget(self.label_10)
+        self.label_11 = QtWidgets.QLabel(self.widget)
+        self.label_11.setObjectName("label_11")
+        self.verticalLayout_2.addWidget(self.label_11)
+        self.label_12 = QtWidgets.QLabel(self.widget)
+        self.label_12.setObjectName("label_12")
+        self.verticalLayout_2.addWidget(self.label_12)
+        self.label_13 = QtWidgets.QLabel(self.widget)
+        self.label_13.setObjectName("label_13")
+        self.verticalLayout_2.addWidget(self.label_13)
+        self.label_14 = QtWidgets.QLabel(self.widget)
+        self.label_14.setObjectName("label_14")
+        self.verticalLayout_2.addWidget(self.label_14)
+        self.label_15 = QtWidgets.QLabel(self.widget)
+        self.label_15.setObjectName("label_15")
+        self.verticalLayout_2.addWidget(self.label_15)
+        self.label_16 = QtWidgets.QLabel(self.widget)
+        self.label_16.setObjectName("label_16")
+        self.verticalLayout_2.addWidget(self.label_16)
+        self.label_17 = QtWidgets.QLabel(self.widget)
+        self.label_17.setObjectName("label_17")
+        self.verticalLayout_2.addWidget(self.label_17)
+        self.label_19 = QtWidgets.QLabel(self.widget)
+        self.label_19.setObjectName("label_19")
+        self.verticalLayout_2.addWidget(self.label_19)
+        self.verticalLayout_3.addLayout(self.verticalLayout_2)
+
+
+
+
+
+
+        self.pushButton.clicked.connect(self.openfolder_video_path)
+        self.pushButton_2.clicked.connect(self.openfolder_images_path)
+        self.pushButton_3.clicked.connect(self.button_click)
+        self.pushButton_4.clicked.connect(self._calculate_num)
+
+
+        self.retranslateUi(Form)
+        QtCore.QMetaObject.connectSlotsByName(Form)
+
+    def retranslateUi(self, Form):
+        _translate = QtCore.QCoreApplication.translate
+        Form.setWindowTitle(_translate("Form", "Form"))
+        self.label_9.setText(_translate("Form", "*"))
+        self.label_1.setText(_translate("Form", "Video folder:  "))
+        self.pushButton.setText(_translate("Form", "Choose folder"))
+        self.label_18.setText(_translate("Form", "*"))
+        self.label_2.setText(_translate("Form", "Image save path:"))
+        self.pushButton_2.setText(_translate("Form", "Choose folder"))
+        self.label_3.setText(_translate("Form", " Video file extensions:"))
+        self.label_20.setText(_translate("Form", "*"))
+        self.label_4.setText(_translate("Form", " Image scene name: "))
+        self.label_6.setText(_translate("Form", " Video split rule:"))
+        self.radioButton.setText(_translate("Form", "Split by frame count"))
+        self.radioButton_2.setText(_translate("Form", " Total number of images wanted"))
+        self.label_7.setText(_translate("Form", " Total images:  "))
+        self.pushButton_4.setText(_translate("Form", "Compute frame interval"))
+        self.label_8.setText(_translate("Form", "*"))
+        self.label_5.setText(_translate("Form", " Frame interval:"))
+        self.pushButton_3.setText(_translate("Form", "Start splitting"))
+        self.label_10.setText(_translate("Form", "Instructions:"))
+        self.label_11.setText(_translate("Form", "1. Use the button next to 'Video folder' to pick the folder of videos to split"))
+        self.label_12.setText(_translate("Form", "2. Pick the image save path the same way"))
+        self.label_13.setText(_translate("Form", "3. Append new video extensions to the list, separated by half-width commas"))
+        self.label_14.setText(_translate("Form", "4. The image scene name should usually match the project name"))
+        self.label_15.setText(_translate("Form", "5. Split rules: 5-1 by frame count; 5-2 by total image count"))
+        self.label_16.setText(_translate("Form", "    5-1. Enter a number under 'Frame interval'; the default is one image every 20 frames"))
+        self.label_17.setText(_translate("Form", "    5-2. Enter the total image count and click 'Compute frame interval'"))
+        self.label_19.setText(_translate("Form", "6. Click 'Start splitting'"))
+
+    def _calculate_num(self):
+        """
+        Preparatory checks for calculate_num().
+
+        :return:
+        """
+        print('Calculating...')
+        if self.lineEdit_3.text():
+            num_str = self.lineEdit_6.text()
+            if num_str and num_str.isdigit() and int(num_str) > 0:
+
+                video_dir = self.lineEdit_3.text().replace('\\', '/')
+                # Normalize the extension separators; full-width commas are rejected.
+                if ',' in self.lineEdit.text() and ',' not in self.lineEdit.text():
+                    vidoes_extension_list = [lin for lin in self.lineEdit.text().split(',')]
+                    vidoes_extension = ','.join(set(vidoes_extension_list))
+                    self.calculate_num(video_dir, vidoes_extension, int(num_str))
+
+                elif ',' not in self.lineEdit.text() and ',' not in self.lineEdit.text():
+                    vidoes_extension = self.lineEdit.text()
+                    self.calculate_num(video_dir, vidoes_extension, int(num_str))
+
+                else:
+                    self.show_message_extension()
+            else:
+                self.show_message_num()
+        else:
+            self.show_message_folder()
+
+    def calculate_num(self, video_dir, vidoes_extension, num):
+        """
+        Derive the frame interval from the desired total number of images.
+        :param video_dir: video folder
+        :param vidoes_extension: video file extensions
+        :param num: total number of images wanted
+        :return:
+        """
+        video_path_list = extension_filter(video_dir, vidoes_extension)
+        all_frames = 0
+        for video in video_path_list:
+            # Property 7 is cv2.CAP_PROP_FRAME_COUNT.
+            num_frames = cv2.VideoCapture(video).get(7)
+            all_frames += num_frames
+        if all_frames <= 0:
+            # Unknown codecs report zero or negative frame counts; fall back to frame-based splitting.
+            self.show_message_videos_format()
+        else:
+            frame_num = int(all_frames / num)
+            self.radioButton_2.setChecked(True)
+            self.lineEdit_5.setText(str(frame_num))
+
+    def openfolder_video_path(self):
+        """
+        选择视频文件夹路径
+        :return:
+        """
+        openfolder_path = QFileDialog.getExistingDirectory(self, '打开文件夹', self.cwd)
+        self.lineEdit_3.setText(openfolder_path)
+        self.lineEdit_4.setText(os.path.join(openfolder_path, 'images'))
+
+    def openfolder_images_path(self):
+        """
+        选择图片文件保存路径
+        :return:
+        """
+        openfolder_path = QFileDialog.getExistingDirectory(self, '选择图片保存路径', self.cwd)
+        self.lineEdit_4.setText(openfolder_path)
+
+    def button_click(self):
+        """
+        Validate the text boxes on submit, making sure the required fields are filled in.
+        :return:
+        """
+        if self.lineEdit.text() and self.lineEdit_2.text() and self.lineEdit_3.text() and self.lineEdit_4.text() and self.lineEdit_5.text():
+            frame_num = self.lineEdit_5.text()
+            if frame_num.isdigit() and int(frame_num) > 0:
+                video_dir = self.lineEdit_3.text().replace('\\', '/')
+                save_dir = self.lineEdit_4.text().replace('\\', '/')
+                # Normalize the extension separators; full-width commas are rejected.
+                if ',' in self.lineEdit.text() and ',' not in self.lineEdit.text():
+                    vidoes_extension_list = [lin for lin in self.lineEdit.text().split(',')]
+                    vidoes_extension = ','.join(set(vidoes_extension_list))
+                    Video2Image(video_dir, save_dir, vidoes_extension, self.lineEdit_2.text(),
+                                int(frame_num)).main()
+                    self.show_message_succes()
+                elif ',' not in self.lineEdit.text() and ',' not in self.lineEdit.text():
+                    vidoes_extension = self.lineEdit.text()
+                    Video2Image(video_dir, save_dir, vidoes_extension, self.lineEdit_2.text(),
+                                int(frame_num)).main()
+                    self.show_message_succes()
+
+                else:
+                    self.show_message_extension()
+            else:
+                self.show_message_num()
+        else:
+            self.show_message_isempty()
+
+    def show_message_succes(self):
+        """
+        消息提示框
+        :return:
+        """
+        QMessageBox.about(self, "信息提示", "视频切分完成!")
+
+    def show_message_folder(self):
+        """
+        消息提示框
+        :return:
+        """
+        QMessageBox.about(self, "空值提示", "没有获取到相应的文件夹路径!")
+
+    def show_message_num(self):
+        """
+        Message box.
+        :return:
+        """
+        QMessageBox.about(self, "Number error", "The input is not a number, or it is not greater than 0!")
+
+    def show_message_extension(self):
+        """
+        Message box.
+        :return:
+        """
+        QMessageBox.about(self, "Extension error", "Please separate extensions with half-width commas [,]!")
+
+    def show_message_isempty(self):
+        """
+        Message box.
+        :return:
+        """
+        QMessageBox.about(self, "Empty field error", "Text fields cannot be empty, please check!")
+
+    def show_message_videos_format(self):
+        """
+        Message box.
+        :return:
+        """
+        QMessageBox.about(self, "Video codec error", "A video has an unknown codec; only frame-interval splitting is possible")
+
+
+class Video2Image:
+    def __init__(self, video_dir, save_dir, vidoes_extension, file_name, frame_num):
+        self.video_dir = video_dir  # video directory
+        self.save_dir = save_dir  # directory where the images are saved
+        self.vidoes_extension = vidoes_extension  # video extension(s)
+        self.file_name = file_name  # image filename prefix
+        self.frame_num = frame_num  # frame sampling interval
+
+    def get_frame_split(self, frames: int, all_image_num: int):
+        """
+        When frame_num is 0 (i.e. a total image count is given instead), compute the
+        sampling interval as the video's total frame count divided by that image count.
+
+        :param frames: total number of frames in the video
+        :param all_image_num: number of images to extract
+        :return: sampling interval in frames
+        """
+        frame_split = math.ceil(frames / all_image_num)
+        if frame_split < 1:
+            sys.exit('%s/%s is less than 1; the video cannot be split. Check the all_image_num field or whether the video has too few frames' % (frames, all_image_num))
+
+        return frame_split
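+    # e.g. get_frame_split(1000, 400) == math.ceil(1000 / 400) == 3: sampling every 3rd
+    # frame yields about 333 images, never more than requested (hypothetical numbers).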
+
+    def gen_images(self, frame_split, vc, video_name):
+        """
+        Split a video into images saved in the target folder. Image names follow:
+        configured prefix + video filename + timestamp + frame index + '.jpg'.
+
+        :param frame_split: sampling interval in frames
+        :param vc: cv2.VideoCapture object
+        :param video_name: video filename
+        :return:
+        """
+        time_str = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))
+        video = '.'.join(video_name.split('.')[0:-1])
+
+        rval, _ = vc.read() if vc.isOpened() else (False, None)
+        c, count = 0, 0
+        total = int(vc.get(cv2.CAP_PROP_FRAME_COUNT))
+        str_len = len(str(total))
+        for c in tqdm(range(0, total), total=total, ncols=70):
+            rval, frame = vc.read()
+            c += 1
+            if c % frame_split == 0:
+                # extract one image every frame_split frames
+                image_name = self.file_name + '_' + video + '_' + time_str + '_' + str(c).zfill(str_len) + '.jpg'
+                image_save_name = self.save_dir + "/" + image_name
+                try:
+                    if os.path.exists(image_save_name):
+                        print('File already exists, overwriting')
+                    cv2.imencode('.jpg', frame)[1].tofile(image_save_name)
+                    count += 1
+                except Exception:
+                    print('Failed to save one file')
+                continue
+        print('%s: saved %d images in total' % (video_name, count))
+        return count
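+    # Example of the naming rule above with hypothetical inputs: prefix 'cam', video
+    # 'clip01.mp4', run at 2021-07-08 20:19, frame 120 of 3450 -> 'cam_clip01_202107082019_0120.jpg'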
+
+    def gen_images_2(self, frame_split, vc, video_name):
+        """
+        Split videos whose total frame count OpenCV cannot read. Images are saved to the
+        target folder with the same naming rule:
+        configured prefix + video filename + timestamp + frame index + '.jpg'.
+
+        :param frame_split: sampling interval in frames
+        :param vc: cv2.VideoCapture object
+        :param video_name: video filename
+        :return:
+        """
+        time_str = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))
+        video = '.'.join(video_name.split('.')[0:-1])
+
+        frame_count, count = 0, 0
+
+        with tqdm(total=5000) as pbar:  # total frame count is unknown; 5000 is only a progress-bar guess
+            while True:
+                ret, frame = vc.read()
+                if not ret:
+                    break
+                else:
+                    frame_count = frame_count + 1
+                    if frame_count % frame_split == 0:
+                        image_name = self.file_name + '_' + video + '_' + time_str + '_' + str(frame_count) + '.jpg'
+                        image_save_name = self.save_dir + "/" + image_name
+                        pbar.update(1)
+                        pbar.set_description("Extracting image %d" % count)
+                        try:
+                            if os.path.exists(image_save_name):
+                                print('File already exists, overwriting')
+                            cv2.imencode('.jpg', frame)[1].tofile(image_save_name)
+                            count += 1
+                        except Exception:
+                            continue
+        print('%s: saved %d images in total' % (video_name, count))
+        return count
+
+    def video_split(self, video_path):
+        """
+        Read the video's frame count and dispatch to gen_images(), or to gen_images_2()
+        when OpenCV cannot report a frame count.
+        :param video_path: path to the video
+        :return: number of images saved
+        """
+        video_name = video_path.split("\\")[-1].split("/")[-1]
+        vc = cv2.VideoCapture(video_path)
+        frames = int(vc.get(cv2.CAP_PROP_FRAME_COUNT))
+        print("%s: total frames %d, extracting images..." % (video_name, frames))
+        # Originally the config decided between all_image_num and frame_num:
+        # frame_split = self.get_frame_split(frames, self.all_image_num) if self.frame_num < 1 else self.frame_num
+        frame_split = self.frame_num
+        if frames < 0:
+            count = self.gen_images_2(frame_split, vc, video_name)
+        else:
+            count = self.gen_images(frame_split, vc, video_name)
+        return count
+
+    def video_process(self):
+        """
+        Iterate over all matching video files under the directory and pass each to video_split().
+        :return:
+        """
+        vidoes_count = 0
+        images_count = 0
+        video_path_list = extension_filter(self.video_dir, self.vidoes_extension)
+        for video_path in video_path_list:
+            images_count += self.video_split(video_path)
+            vidoes_count += 1
+        print('Read %d videos, extracted %d images' % (vidoes_count, images_count))
+
+    def main(self):
+
+        if not os.path.exists(self.save_dir):
+            os.makedirs(self.save_dir)
+        self.video_process()
+
+
+def extension_filter(base, extension_str):
+    """
+    Collect files with the given extensions under a directory (recursively) and return their paths.
+
+    :param base: root directory
+    :param extension_str: comma-separated extensions (e.g. from the conf file)
+    :return: list of matched file paths
+    """
+    extension = extension_str.split(',')
+    fullname_list = []
+    for ex in extension:
+        ex = ex.strip() if ex.strip().startswith('.') else '.' + ex.strip()  # normalize the extension to start with a dot
+        ex_list = glob.glob(base + '/**/*' + ex, recursive=True)
+        fullname_list.extend(ex_list)
+    return fullname_list
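+# A minimal usage sketch of extension_filter (paths and extensions are hypothetical):
+#   extension_filter('/data/videos', 'mp4, avi')
+#   -> ['/data/videos/a/clip1.mp4', '/data/videos/b/clip2.avi', ...]
+# Extensions may be given with or without the leading dot; each is normalized above.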
+
+
+if __name__ == "__main__":
+    app = QtWidgets.QApplication(sys.argv)
+    main_window = QtWidgets.QMainWindow()
+    ui = Ui_Form()
+    ui.setupUi(main_window)
+    main_window.show()
+    sys.exit(app.exec_())

+ 44 - 0
code/data_manage/test_util/Qt5/videos2imgs/form1.3.spec

@@ -0,0 +1,44 @@
+# -*- mode: python ; coding: utf-8 -*-
+
+
+block_cipher = None
+
+
+a = Analysis(
+    ['form1.3.py'],
+    pathex=[],
+    binaries=[],
+    datas=[],
+    hiddenimports=[],
+    hookspath=[],
+    hooksconfig={},
+    runtime_hooks=[],
+    excludes=[],
+    win_no_prefer_redirects=False,
+    win_private_assemblies=False,
+    cipher=block_cipher,
+    noarchive=False,
+)
+pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
+
+exe = EXE(
+    pyz,
+    a.scripts,
+    a.binaries,
+    a.zipfiles,
+    a.datas,
+    [],
+    name='form1.3',
+    debug=False,
+    bootloader_ignore_signals=False,
+    strip=False,
+    upx=True,
+    upx_exclude=[],
+    runtime_tmpdir=None,
+    console=True,
+    disable_windowed_traceback=False,
+    argv_emulation=False,
+    target_arch=None,
+    codesign_identity=None,
+    entitlements_file=None,
+)

+ 28 - 0
code/data_manage/test_util/filter_file.py

@@ -0,0 +1,28 @@
+"""
+# File       : filter_file.py
+# Time       :21.6.28 10:17
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: copy the first half of the video files from every subdirectory of a given directory
+"""
+import glob
+import shutil
+import os
+
+old_path = '/data2/share/淮南洛河电厂/20210625/巡检点/近视角'
+new_path = '/data2/fengyang/sunwin/data/image/huainan_luohedianchang/vidoes'
+if not os.path.exists(new_path):
+    os.makedirs(new_path)
+
+for sub_dir in glob.glob(old_path + '/*'):
+    filelists = []
+    for i in glob.glob(sub_dir + '/*.mp4'):
+        filelists.append(i)
+
+    # keep only the first half of the videos in this subdirectory
+    lists = filelists[0:int(len(filelists) / 2)]
+    for li in lists:
+        shutil.copyfile(li, new_path+'/'+li.split('/')[-1])

+ 8 - 0
code/data_manage/test_util/find_diff/__init__.py

@@ -0,0 +1,8 @@
+"""
+# File       : __init__.py
+# Time       :21.6.24 15:28
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""

+ 111 - 0
code/data_manage/test_util/find_diff/find_different.py

@@ -0,0 +1,111 @@
+"""
+# File       : find_different.py
+# Time       :21.6.17 16:43
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""
+import glob
+
+
+for i in glob.glob('C:\\Users\\Administrator\\Desktop\\experimental' + '\\*.jpg'):
+    file_name = i.split('\\')[-1].split('.jpg')[0]+'.xml'
+    strs = """
+    <annotation>
+	<folder>experimental</folder>
+	<filename>%s</filename>
+	<path>%s</path>
+	<source>
+		<database>Unknown</database>
+	</source>
+	<size>
+		<width>1920</width>
+		<height>1080</height>
+		<depth>3</depth>
+	</size>
+	<segmented>0</segmented>
+	<object>
+		<name>F</name>
+		<pose>Unspecified</pose>
+		<truncated>0</truncated>
+		<difficult>0</difficult>
+		<bndbox>
+			<xmin>794</xmin>
+			<ymin>541</ymin>
+			<xmax>855</xmax>
+			<ymax>616</ymax>
+		</bndbox>
+	</object>
+	<object>
+		<name>T</name>
+		<pose>Unspecified</pose>
+		<truncated>0</truncated>
+		<difficult>0</difficult>
+		<bndbox>
+			<xmin>890</xmin>
+			<ymin>533</ymin>
+			<xmax>957</xmax>
+			<ymax>609</ymax>
+		</bndbox>
+	</object>
+	<object>
+		<name>F</name>
+		<pose>Unspecified</pose>
+		<truncated>0</truncated>
+		<difficult>0</difficult>
+		<bndbox>
+			<xmin>993</xmin>
+			<ymin>526</ymin>
+			<xmax>1052</xmax>
+			<ymax>605</ymax>
+		</bndbox>
+	</object>
+	<object>
+		<name>T</name>
+		<pose>Unspecified</pose>
+		<truncated>0</truncated>
+		<difficult>0</difficult>
+		<bndbox>
+			<xmin>1085</xmin>
+			<ymin>521</ymin>
+			<xmax>1152</xmax>
+			<ymax>598</ymax>
+		</bndbox>
+	</object>
+	<object>
+		<name>F</name>
+		<pose>Unspecified</pose>
+		<truncated>0</truncated>
+		<difficult>0</difficult>
+		<bndbox>
+			<xmin>1191</xmin>
+			<ymin>514</ymin>
+			<xmax>1251</xmax>
+			<ymax>589</ymax>
+		</bndbox>
+	</object>
+	<object>
+		<name>F</name>
+		<pose>Unspecified</pose>
+		<truncated>0</truncated>
+		<difficult>0</difficult>
+		<bndbox>
+			<xmin>1285</xmin>
+			<ymin>507</ymin>
+			<xmax>1346</xmax>
+			<ymax>584</ymax>
+		</bndbox>
+	</object>
+    </annotation>"""%(i.split('\\')[-1], i)
+    with open('\\'.join(i.split('\\')[:-1])+'\\'+file_name, 'w+') as f:
+        f.writelines(strs)
+

+ 8 - 0
code/data_manage/test_util/img_mask/__init__.py

@@ -0,0 +1,8 @@
+"""
+# File       : __init__.py
+# Time       :21.6.24 15:27
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""

+ 257 - 0
code/data_manage/test_util/img_mask/imgmask.py

@@ -0,0 +1,257 @@
+"""
+# File       : imgmask.py
+# Time       :21.6.15 9:41
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: overlay foreground images onto background images to generate new images
+"""
+
+import cv2
+import glob
+import tqdm
+import copy
+import random
+import numpy as np
+import pandas as pd
+from test_util.img_mask import pathUtil
+from xml.dom import minidom
+
+
+def cv_imread(filePath):
+    """
+    Read an image; works around cv2.imread() failing on paths containing Chinese characters.
+    :param filePath: file path
+    :return:
+    """
+    return cv2.imdecode(np.fromfile(filePath, dtype=np.uint8), -1)
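+# Usage sketch (hypothetical path): cv_imread('C:/数据/图片.jpg'). Decoding the raw bytes
+# with cv2.imdecode lets NumPy handle the non-ASCII path instead of cv2.imread.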
+
+
+def random_size(img1, img2):
+    """
+    Randomly rescale the foreground image; it must stay smaller than the background.
+
+    :param img1: background image
+    :param img2: foreground (feature) image
+    :return:
+    """
+
+    img1_rows, img1_cols = img1.shape[:2]
+    img2_rows, img2_cols = img2.shape[:2]
+    while True:
+        rd_num = random.randrange(3, 20) / 10
+        rows = img2_rows * rd_num
+        cols = img2_cols * rd_num
+
+        if rows < img1_rows and cols < img1_cols:
+            img2 = cv2.resize(img2, (0, 0), fx=rd_num, fy=rd_num, interpolation=cv2.INTER_NEAREST)
+            return img2
+
+
+def random_step(bng, dst):
+    """
+    Pick a random position in the background image at which to place the feature image.
+    :param bng: background image
+    :param dst: foreground (feature) image
+    :return:
+    """
+    bng_rows, bng_cols = bng.shape[:2]
+    dst_rows, dst_cols = dst.shape[:2]
+    rows_step = random.randrange(0, bng_rows - 1 - dst_rows)
+    cols_step = random.randrange(0, bng_cols - 1 - dst_cols)
+    return rows_step, cols_step
+
+
+def img2mask(img1, img2):
+    img2 = random_size(img1, img2)
+    # Composite in the top-left corner first, so only this region matters here
+    rows, cols = img2.shape[:2]
+    roi = img1[:rows, :cols]
+
+    # Build the mask
+    img2gray = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
+    ret, mask = cv2.threshold(img2gray, 10, 255, cv2.THRESH_BINARY)
+    mask_inv = cv2.bitwise_not(mask)
+
+    # Keep the background everywhere except the foreground shape
+    img1_bg = cv2.bitwise_and(roi, roi, mask=mask_inv)
+    dst = cv2.add(img1_bg, img2)  # blend
+    rows_step, cols_step = random_step(img1, dst)
+    x_min, x_max = rows_step, rows + rows_step
+    y_min, y_max = cols_step, cols + cols_step
+    img1[x_min:x_max, y_min:y_max] = dst  # paste the blended patch back at a random position
+    return img1, x_min, x_max, y_min, y_max
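+# Rough shape of the compositing above, assuming a hypothetical 100x100 background
+# and 30x30 foreground (real inputs come from cv_imread):
+#   bg = np.zeros((100, 100, 3), np.uint8)
+#   fg = np.full((30, 30, 3), 200, np.uint8)
+#   merged, x0, x1, y0, y1 = img2mask(bg, fg)  # merged[x0:x1, y0:y1] holds the blend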
+
+
+def random_class_img(class_imgs_path):
+    """
+    Pick a random feature image.
+    :param class_imgs_path:
+    :return:
+    """
+    img_lists = [i for i in glob.glob(class_imgs_path + '/*.jpg')]
+    path = random.choice(img_lists)
+    return pathUtil.path_format(path)
+
+
+def write_xml(img, df_t, whd_list, total_data_dir):
+    """
+    Build the annotation XML, fill it in and save it.
+
+    :param img: path of the new image with feature images added
+    :param df_t: coordinates of the feature images in this image
+    :param whd_list: width/height/depth of the new image
+    :param total_data_dir: directory where the XML files are saved
+    :return:
+    """
+    filename = img.split('/')[-1]
+
+    # 1. Create the DOM document
+    doc = minidom.Document()
+
+    # 2. Create the root node and attach it to the document
+    root_node = doc.createElement("annotation")
+    doc.appendChild(root_node)
+
+    # 3. Create element nodes holding text nodes and attach them to the root
+    folder_node = doc.createElement("folder")
+    folder_value = doc.createTextNode('ZS')
+    folder_node.appendChild(folder_value)
+    root_node.appendChild(folder_node)
+
+    filename_node = doc.createElement("filename")
+    filename_value = doc.createTextNode(filename)
+    filename_node.appendChild(filename_value)
+    root_node.appendChild(filename_node)
+
+    path_node = doc.createElement("path")
+    path_value = doc.createTextNode(img)
+    path_node.appendChild(path_value)
+    root_node.appendChild(path_node)
+
+    source_node = doc.createElement("source")
+    database_node = doc.createElement("database")
+    database_node.appendChild(doc.createTextNode("Unknown"))
+    source_node.appendChild(database_node)
+    root_node.appendChild(source_node)
+
+    size_node = doc.createElement("size")
+    for item, value in zip(["width", "height", "depth"], whd_list):
+        elem = doc.createElement(item)
+        elem.appendChild(doc.createTextNode(str(value)))
+        size_node.appendChild(elem)
+    root_node.appendChild(size_node)
+
+    seg_node = doc.createElement("segmented")
+    seg_node.appendChild(doc.createTextNode(str(0)))
+    root_node.appendChild(seg_node)
+
+    for _, df in df_t.iterrows():
+        obj_node = doc.createElement("object")
+        name_node = doc.createElement("name")
+        name_node.appendChild(doc.createTextNode(str(df['class'])))
+        obj_node.appendChild(name_node)
+
+        pose_node = doc.createElement("pose")
+        pose_node.appendChild(doc.createTextNode("Unspecified"))
+        obj_node.appendChild(pose_node)
+
+        trun_node = doc.createElement("truncated")
+        trun_node.appendChild(doc.createTextNode(str(0)))
+        obj_node.appendChild(trun_node)
+
+        trun_node = doc.createElement("difficult")
+        trun_node.appendChild(doc.createTextNode(str(0)))
+        obj_node.appendChild(trun_node)
+
+        bndbox_node = doc.createElement("bndbox")
+        for item, value in zip(["xmin", "ymin", "xmax", "ymax"], [df['xmin'], df['ymin'], df['xmax'], df['ymax']]):
+            elem = doc.createElement(item)
+            elem.appendChild(doc.createTextNode(str(value)))
+            bndbox_node.appendChild(elem)
+        obj_node.appendChild(bndbox_node)
+        root_node.appendChild(obj_node)
+
+    xml_file = filename.split('.')[0] + '.xml'
+    with open(pathUtil.path_format_join(total_data_dir, xml_file), "w", encoding="utf-8") as f:
+        # 4. writexml(): the arguments are the target file object, the root indent,
+        # the child-node indent, the newline string, and the XML encoding.
+        doc.writexml(f, indent='', addindent='\t', newl='\n', encoding="utf-8")
+
+
+def check_class_img(df_t, df_tmp):
+    """
+    Check whether the overlaid foreground images would overlap each other.
+    df_tmp holds the candidate box and df_t the boxes already placed; the candidate is
+    compared against each existing box and passes only if it overlaps none of them.
+    :param df_t: dataframe with coordinates of the foreground images already placed
+    :param df_tmp: coordinates of the foreground image about to be placed
+    :return:
+    """
+    flag_x = False
+    flag_y = False
+    if not df_t.empty:
+        for _, df in df_t.iterrows():
+
+            if df['xmin'] < df_tmp['xmin'][0] and df['xmax'] < df_tmp['xmin'][0]:
+                flag_x = True
+            elif df['xmin'] > df_tmp['xmin'][0] and df_tmp['xmax'][0] < df['xmin']:
+                flag_x = True
+            else:
+                flag_x = False
+
+            if df['ymin'] < df_tmp['ymin'][0] and df['ymax'] < df_tmp['ymin'][0]:
+                flag_y = True
+            elif df['ymin'] > df_tmp['ymin'][0] and df_tmp['ymax'][0] < df['ymin']:
+                flag_y = True
+            else:
+                flag_y = False
+            # boxes overlap unless they are separated along at least one axis
+            if not (flag_x or flag_y):
+                return False
+        return True
+    else:
+        return True
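+# Example of the separation test with hypothetical boxes: an existing box
+# (xmin=0, xmax=10) and a candidate (xmin=20, xmax=30) are disjoint along x, so the
+# candidate passes even if their y ranges overlap.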
+
+
+def main():
+    # cropped feature images
+    class_img_dir = r'C:\Users\Administrator\Desktop\kuye\class_img_dir'
+    # dataset onto which the feature images are overlaid
+    total_img_dir = r'C:\Users\Administrator\Desktop\test_image\test_images'
+    # output directory for the augmented dataset (.jpg and .xml files)
+    total_data_dir = r'C:\Users\Administrator\Desktop\kuye\total_data'
+    # feature image classes
+    class_list = ['a1', 'a2', 'b1', 'b2', 'c1', 'c2']
+    df = pd.DataFrame()
+    glob_list = glob.glob(total_img_dir + '/*.jpg')
+    for image in tqdm.tqdm(glob_list, total=len(glob_list), ncols=80):
+        img = pathUtil.path_format(image)
+        df_t = pd.DataFrame()
+        num = random.randint(1,4)
+        image1_name = img.split('/')[-1]
+        img1 = cv_imread(img)
+
+        for i in range(num):
+            flag = False
+            cla = random.choice(class_list)
+            class_imgs_path = pathUtil.path_format_join(class_img_dir, cla)
+            class_img_path = random_class_img(class_imgs_path)
+            img2 = cv_imread(class_img_path)
+            while not flag:
+                img1_tmp = copy.deepcopy(img1)
+                img12, xmin, xmax, ymin, ymax = img2mask(img1_tmp, img2)
+                list_tmp = [[image1_name, img1.shape[:2][0], img1.shape[:2][1], cla, xmin, ymin, xmax, ymax]]
+                df_tmp = pd.DataFrame(list_tmp,
+                                      columns=['filename', 'width', 'height', 'class', 'ymin', 'xmin', 'ymax', 'xmax'])
+                flag = check_class_img(df_t, df_tmp)
+
+            img1 = img12
+            df_t = pd.concat([df_t, df_tmp])
+
+        write_xml(img, df_t, img1.shape, total_data_dir)
+        df = pd.concat([df, df_t])
+        img_path = pathUtil.path_format_join(total_data_dir, image1_name)
+        cv2.imwrite(img_path, img1)
+
+
+# run the program
+if __name__ == '__main__':
+    main()

+ 33 - 0
code/data_manage/test_util/img_mask/pathUtil.py

@@ -0,0 +1,33 @@
+"""
+# File       : pathUtil.py
+# Time       :21.5.25 18:13
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: path manipulation utilities
+"""
+import os
+import shutil
+from gen_data.utils import strUtil
+
+def mkdir(new_folder):
+    if not os.path.exists(new_folder):
+        os.mkdir(new_folder)
+
+def mkdir_new(new_folder):
+    # recreate the folder from scratch
+    if os.path.exists(new_folder):
+        shutil.rmtree(new_folder)
+    os.mkdir(new_folder)
+
+def path_format(path_str):
+    path = strUtil.profile2str(path_str.replace('\\','/'))
+    if str(path).endswith('/'):
+        return str(path)[0:-1]
+    else:
+        return path
+
+def path_format_join(path_str1, path_str2):
+    return os.path.join(path_format(path_str1), path_format(path_str2)).replace('\\','/')
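+# Usage sketch with hypothetical paths (assuming strUtil.profile2str returns plain
+# strings unchanged):
+#   path_format('C:\\data\\imgs\\')        -> 'C:/data/imgs'
+#   path_format_join('C:\\data', 'a.jpg')  -> 'C:/data/a.jpg'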
+
+

+ 29 - 0
code/data_manage/test_util/line2point/draw_line.py

@@ -0,0 +1,29 @@
+"""
+# File       : draw_line.py
+# Time       :21.10.19 8:54
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""
+import cv2
+
+import glob
+path = r"C:\Users\Administrator\Desktop\ICCV\ICCV2017_JTLEE_gtlines_all"
+path2 = r"C:\Users\Administrator\Desktop\ICCV\ICCV2017_JTLEE_images"
+for i in glob.glob(path2+'/*.jpg'):
+    i_file = path+"/"+i.split("\\")[-1].replace('.jpg', '.txt')
+    with open(i_file, 'r') as f:
+        i_txt = f.readlines()
+    t_list = []
+    for txt in i_txt:
+        t = txt.strip().split(', ')[:-2]
+        t_list.append(t)
+    img = cv2.imread(i)
+    cv2.line(img, (int(float(t_list[0][0])), int(float(t_list[0][1]))), (int(float(t_list[0][2])), int(float(t_list[0][3]))), (0,255,0), thickness=1)
+    cv2.line(img, (int(float(t_list[1][0])), int(float(t_list[1][1]))),
+             (int(float(t_list[1][2])), int(float(t_list[1][3]))), (0, 255, 0), thickness=1)
+    cv2.imwrite(path+'/'+i.split("\\")[-1], img)
+

+ 270 - 0
code/data_manage/test_util/line2point/get_convert_belt_info(2).py

@@ -0,0 +1,270 @@
+import math
+from sklearn.cluster import KMeans
+import cv2
+import numpy as np
+
+
+class ConvertBeltInfo:
+    def __init__(self, hough_lines_thread=150):
+        self.hough_lines_thread = hough_lines_thread
+
+    @staticmethod
+    def _get_line(line, original_img):
+        rho = line[0]  # the first element is the distance rho
+        theta = line[1]
+        point_list = []
+        if (theta < (np.pi / 4.)) or (theta > (3. * np.pi / 4.0)):  # near-vertical line
+            pt1 = (int(rho / np.cos(theta)), 0)  # intersection with the first row
+            point_list.extend([pt1])
+            # intersection with the last row
+            pt2 = (int((rho - original_img.shape[0] * np.sin(theta)) / np.cos(theta)), original_img.shape[0])
+            point_list.extend([pt2])
+
+        else:  # near-horizontal line
+            pt1 = (0, int(rho / np.sin(theta)))  # intersection with the first column
+            point_list.extend([pt1])
+            # intersection with the last column
+            pt2 = (original_img.shape[1], int((rho - original_img.shape[1] * np.cos(theta)) / np.sin(theta)))
+            point_list.extend([pt2])
+
+        k = (point_list[1][1] - point_list[0][1]) / (point_list[1][0] - point_list[0][0])
+        b = point_list[0][1] - k * point_list[0][0]
+        return k, b, point_list
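+    # Example of the polar->Cartesian conversion above, with hypothetical values: a
+    # vertical Hough line (rho=500, theta=0) on a 1080-high image crosses the first and
+    # last rows at pt1=(500, 0) and pt2=(500, 1080). Note that the k, b computation
+    # divides by zero for an exactly vertical line; Hough angles are rarely exactly 0.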
+
+    @staticmethod
+    def _get_point(line, image_shape):
+        width = image_shape[1]
+        height = image_shape[0]
+        point_img = []
+        if 0 <= (line[0] * 0 + line[1]) <= height:
+            point_img.extend([(0, int(line[0] * 0 + line[1]))])
+
+        if 0 <= (line[0] * width + line[1]) <= height:
+            point_img.extend([(width, int(line[0] * width + line[1]))])
+
+        if 0 <= (0 - line[1] + float('1e-8')) / (line[0] + float('1e-8')) <= width:
+            point_img.extend([(int((0 - line[1] + float('1e-8')) / (line[0] + float('1e-8'))), 0)])
+
+        if 0 <= (height - line[1] + float('1e-8')) / (line[0] + float('1e-8')) <= width:
+            point_img.extend([(int((height - line[1] + float('1e-8')) / (line[0] + float('1e-8'))), height)])
+        return point_img
+
+    @staticmethod
+    def _list_remove(lists, a):
+        if a in lists:
+            lists.remove(a)
+
+    def _max_length(self, lines, image):
+        lines_point_list = []
+        for line in lines:
+            _, _, point_list = self._get_line(line[0], image)  # endpoints where the line meets the image borders
+            lines_point_list.extend(point_list)
+        lines_point_np_array = np.array(lines_point_list)
+        kmeans = KMeans(n_clusters=4, random_state=0).fit(lines_point_np_array)
+        point_0_list = []
+        point_1_list = []
+        point_2_list = []
+        point_3_list = []
+
+        for i in range(len(kmeans.labels_)):
+            if kmeans.labels_[i] == 0:
+                point_0_list.append(lines_point_list[i])
+            elif kmeans.labels_[i] == 1:
+                point_1_list.append(lines_point_list[i])
+            elif kmeans.labels_[i] == 2:
+                point_2_list.append(lines_point_list[i])
+            elif kmeans.labels_[i] == 3:
+                point_3_list.append(lines_point_list[i])
+        four_points_list = []
+
+        points_line_x_0_list = []
+        points_line_x_width_list = []
+        point_0 = np.mean(np.array(point_0_list), axis=0)
+        point_1 = np.mean(np.array(point_1_list), axis=0)
+        point_2 = np.mean(np.array(point_2_list), axis=0)
+        point_3 = np.mean(np.array(point_3_list), axis=0)
+
+        four_points_list.append([int(p) for p in point_0])
+        four_points_list.append([int(p) for p in point_1])
+        four_points_list.append([int(p) for p in point_2])
+        four_points_list.append([int(p) for p in point_3])
+        for point in four_points_list:
+            if point[0] == 0:
+                points_line_x_0_list.append(point)
+            elif point[0] == image.shape[1]:
+                points_line_x_width_list.append(point)
+
+        points_line_x_0_list.sort()
+        points_line_x_width_list.sort()
+        line1 = [points_line_x_0_list[0], points_line_x_width_list[0]]
+        line2 = [points_line_x_0_list[-1], points_line_x_width_list[-1]]
+        k1 = (line1[1][1] - line1[0][1]) / (line1[1][0] - line1[0][0])
+        b1 = line1[0][1] - k1 * line1[0][0]
+        k2 = (line2[1][1] - line2[0][1]) / (line2[1][0] - line2[0][0])
+        b2 = line2[0][1] - k2 * line2[0][0]
+
+        lines_list = [[k1, b1], [k2, b2]]
+
+
+        return lines_list
+
+    @staticmethod
+    def draw_convert_belt_line_mask(max_len_lines, ori_image, image_mask):
+        image_mask = image_mask / 255 * np.array([111, 222, 80])
+        for line in max_len_lines:
+            rho = line[0]  # the first element is the distance rho
+            theta = line[1]  # the second element is the angle theta
+            if (theta < (np.pi / 4.)) or (theta > (3. * np.pi / 4.0)):  # near-vertical line
+                pt1 = (int(rho / np.cos(theta)), 0)  # intersection with the first row
+                # intersection with the last row
+                pt2 = (int((rho - ori_image.shape[0] * np.sin(theta)) / np.cos(theta)), ori_image.shape[0])
+                ori_image = cv2.line(ori_image, pt1, pt2, (0, 0, 255), thickness=5)
+            else:  # near-horizontal line
+                pt1 = (0, int(rho / np.sin(theta)))  # intersection with the first column
+                # intersection with the last column
+                pt2 = (ori_image.shape[1], int((rho - ori_image.shape[1] * np.cos(theta)) / np.sin(theta)))
+                ori_image = cv2.line(ori_image, pt1, pt2, (0, 0, 255), thickness=5)
+        result_image = cv2.addWeighted(ori_image, 0.3, image_mask, 0.7, 0, dtype=cv2.CV_32F)
+        return result_image
+
+    @staticmethod
+    def get_convert_belt_angle(k1, k2):
+        convert_belt_angle = math.fabs(np.arctan((k1 - k2) / (float(1 + k1 * k2))) * 180 / np.pi)
+        return convert_belt_angle
+
+    @staticmethod
+    def _lines_up_or_up_down(max_len_line1, max_len_line2):
+        top_down_label_dict = {}
+        if max_len_line1[0] <= max_len_line2[1]:
+            top_down_label_dict = {"top": list(max_len_line2), "down": list(max_len_line1)}
+        elif max_len_line1[0] > max_len_line2[1]:
+            top_down_label_dict = {"top": list(max_len_line1), "down": list(max_len_line2)}
+
+        return top_down_label_dict
+
+    def _polar_coordinates_function_2_cartesian_coordinates_point(self, line, image):
+
+        k, b, _ = self._get_line(line, image)
+        point = self._get_point((k, b), image.shape)
+        return point
+
+    def label_dict_to_point(self, top_down_label_dict, image):
+        for k in top_down_label_dict:
+            line = top_down_label_dict.get(k)
+            point = self._polar_coordinates_function_2_cartesian_coordinates_point(line, image)
+            top_down_label_dict[k] = point
+
+        return top_down_label_dict
+
+    def get_bbox(self, point_1, point_2, image_shape, k1, k2):
+        result_order = []
+        line_y_0 = []
+        line_x_1920 = []
+        line_y_1080 = []
+        line_x_0 = []
+
+        positive = [(0, 0), (image_shape[0], image_shape[1])]
+        negative = [(0, image_shape[1]), (image_shape[0], 0)]
+        top_list = [(0, 0), (0, image_shape[1]), (image_shape[0], image_shape[1]), (image_shape[0], 0)]
+        point_list = point_1.copy()
+        point_list.extend(point_2)
+        xs = [x for x, _ in point_list]
+        ys = [y for _, y in point_list]
+
+        if xs.count(0) == 2:
+            self._list_remove(top_list, (0, 0))
+            self._list_remove(top_list, (0, image_shape[1]))
+        if xs.count(image_shape[0]) == 2:
+            self._list_remove(top_list, (image_shape[0], image_shape[1]))
+            self._list_remove(top_list, (image_shape[0], 0))
+
+        if ys.count(0) == 2:
+            self._list_remove(top_list, (0, 0))
+            self._list_remove(top_list, (image_shape[0], 0))
+
+        if ys.count(image_shape[1]) == 2:
+            self._list_remove(top_list, (0, image_shape[1]))
+            self._list_remove(top_list, (image_shape[0], image_shape[1]))
+
+        result = point_list
+        if len(top_list) == 0:
+            pass
+        elif len(top_list) == 2:
+            if k1 >= 0 and k2 >= 0:
+                result.extend(list(set(top_list) & set(positive)))
+            elif k1 < 0 and k2 < 0:
+                result.extend(list(set(top_list) & set(negative)))
+
+        elif len(top_list) == 4:
+            if k1 >= 0 and k2 >= 0:
+                result.extend(list(set(top_list) & set(positive)))
+            elif k1 < 0 and k2 < 0:
+                result.extend(list(set(top_list) & set(negative)))
+        for re in result:
+            if re[1] == 0:
+                line_y_0.append(re)
+            elif re[0] == 1920:
+                line_x_1920.append(re)
+            elif re[1] == 1080:
+                line_y_1080.append(re)
+            elif re[0] == 0:
+                line_x_0.append(re)
+        result_order.extend(sorted(line_y_0, key=lambda line: line[0]))
+        result_order.extend(sorted(line_x_1920, key=lambda line: line[1]))
+        result_order.extend(sorted(line_y_1080, key=lambda line: line[0], reverse=True))
+        result_order.extend(sorted(line_x_0, key=lambda line: line[1], reverse=True))
+
+        return result_order
+
+    def get_top_down_lines(self, image):
+        image = cv2.GaussianBlur(image, (3, 3), 0)
+        edges = cv2.Canny(image, 50, 150, apertureSize=3)
+        lines = cv2.HoughLines(edges, 1, np.pi / 180, self.hough_lines_thread)  # empirically chosen threshold
+        max_len_lines = self._max_length(lines, image)
+        return max_len_lines
+
+    def get_lines_point(self, max_len_lines, image):
+
+        k1, b1, _ = self._get_line(max_len_lines[0], image)
+        k2, b2, _ = self._get_line(max_len_lines[1], image)
+        point_1 = self._get_point((k1, b1), image.shape)
+        point_2 = self._get_point((k2, b2), image.shape)
+        return point_1, point_2, k1, k2, b1, b2
+
+
+    def get_result(self, ori_image, image_mask):
+        max_len_lines = self.get_top_down_lines(image_mask)
+        point_1, point_2, k1, k2, b1, b2 = self.get_lines_point(max_len_lines, image_mask)
+        result_order = self.get_bbox(point_1, point_2, image_mask.shape, k1, k2)
+        convert_belt_angle = self.get_convert_belt_angle(k1, k2)
+        top_down_label_dict = self._lines_up_or_up_down(max_len_lines[0], max_len_lines[1])
+        top_down_label_dict = self.label_dict_to_point(top_down_label_dict, image_mask)
+        two_lines_k_b = [(k1, b1), (k2, b2)]
+        result_image = self.draw_convert_belt_line_mask(max_len_lines, ori_image, image_mask)
+
+        return two_lines_k_b, result_order, top_down_label_dict, convert_belt_angle, result_image
+
+    def draw_image_rectangular_coordinates(self, ori_image, image_mask):
+        """
+        Draw using Cartesian coordinates.
+        """
+        max_len_lines, result_order, top_down_label_dict, _, _ = self.get_result(ori_image, image_mask)
+        print(max_len_lines)
+        # label_dict_to_point() has already converted each entry to border points
+        for key in top_down_label_dict:
+            point = top_down_label_dict[key]
+            image_mask = cv2.line(image_mask, point[0], point[1], (0, 0, 255), thickness=1)
+        cv2.imshow('draw_image', image_mask)
+        cv2.waitKey()
+
+        cv2.imwrite("line_1.jpg", image_mask)
+
+
+
+if __name__ == "__main__":
+    image = cv2.imread('huainanluohe_pidai_192.168.1.122_01_2021062415321311_202107082019_3450.png')
+    cb = ConvertBeltInfo()
+    cb.draw_image_rectangular_coordinates(image, image)
+    # cb.draw_image_rectangular_coordinates(image)

+ 332 - 0
code/data_manage/test_util/line2point/hough_line.py

@@ -0,0 +1,332 @@
+"""
+# File       : hough_line.py
+# Time       :21.7.13 13:36
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""
+import math
+import cv2
+from sklearn.cluster import KMeans
+import numpy as np
+
+
+class ConvertBeltInfo:
+    def __init__(self, hough_lines_thread=150):
+        self.hough_lines_thread = hough_lines_thread
+
+    @staticmethod
+    def _get_line(line, original_img):
+        rho = line[0]  # the first element is the distance rho
+        theta = line[1]
+        point_list = []
+        if (theta < (np.pi / 4.)) or (theta > (3. * np.pi / 4.0)):  # near-vertical line
+            pt1 = (int(rho / np.cos(theta)), 0)  # intersection with the first row
+            point_list.extend([pt1])
+            # intersection with the last row
+            pt2 = (int((rho - original_img.shape[0] * np.sin(theta)) / np.cos(theta)), original_img.shape[0])
+            point_list.extend([pt2])
+
+        else:  # near-horizontal line
+            pt1 = (0, int(rho / np.sin(theta)))  # intersection with the first column
+            point_list.extend([pt1])
+            # intersection with the last column
+            pt2 = (original_img.shape[1], int((rho - original_img.shape[1] * np.cos(theta)) / np.sin(theta)))
+            point_list.extend([pt2])
+
+        k = (point_list[1][1] - point_list[0][1]) / (point_list[1][0] - point_list[0][0])
+        b = point_list[0][1] - k * point_list[0][0]
+        return k, b
+
+    @staticmethod
+    def _get_point(line, image_shape):
+        width = image_shape[1]
+        height = image_shape[0]
+        point_img = []
+        if 0 <= (line[0] * 0 + line[1]) <= height:
+            point_img.extend([(0, int(line[0] * 0 + line[1]))])
+
+        if 0 <= (line[0] * width + line[1]) <= height:
+            point_img.extend([(width, int(line[0] * width + line[1]))])
+
+        if 0 <= (0 - line[1] + float('1e-8')) / (line[0] + float('1e-8')) <= width:
+            point_img.extend([(int((0 - line[1] + float('1e-8')) / (line[0] + float('1e-8'))), 0)])
+
+        if 0 <= (height - line[1] + float('1e-8')) / (line[0] + float('1e-8')) <= width:
+            point_img.extend([(int((height - line[1] + float('1e-8')) / (line[0] + float('1e-8'))), height)])
+        return point_img
+
+    @staticmethod
+    def _list_remove(lists, a):
+        if a in lists:
+            lists.remove(a)
+
+    @staticmethod
+    def _max_length(lines):
+        max_len = 0
+        max_len_index1 = -1
+        max_len_index2 = -1
+        print(len(lines))
+        for i in range(len(lines)):
+            num = i + 1
+            for j in range(num, len(lines)):
+                length = abs(lines[i][0][0] - lines[j][0][0])  # obtain max bias with two lines
+                if max_len < length:
+                    max_len = length
+                    max_len_index1 = i
+                    max_len_index2 = j
+
+        result = []
+        result.extend(lines[max_len_index1])
+        result.extend(lines[max_len_index2])
+        return result
+
+    @staticmethod
+    def _get_line_2(line, original_img):
+        rho = line[0]  # the first element is the distance rho
+        theta = line[1]
+        point_list = []
+        if (theta < (np.pi / 4.)) or (theta > (3. * np.pi / 4.0)):  # near-vertical line
+            pt1 = (int(rho / np.cos(theta)), 0)  # intersection with the first row
+            point_list.extend([pt1])
+            # intersection with the last row
+            pt2 = (int((rho - original_img.shape[0] * np.sin(theta)) / np.cos(theta)), original_img.shape[0])
+            point_list.extend([pt2])
+
+        else:  # near-horizontal line
+            pt1 = (0, int(rho / np.sin(theta)))  # intersection with the first column
+            point_list.extend([pt1])
+            # intersection with the last column
+            pt2 = (original_img.shape[1], int((rho - original_img.shape[1] * np.cos(theta)) / np.sin(theta)))
+            point_list.extend([pt2])
+        return point_list
+
+
+    def _max_length_3(self, lines, image):
+        lines_point_list = []
+        for line in lines:
+            point_list = self._get_line_2(line[0], image)  # endpoints where the line meets the image borders
+            lines_point_list.extend(point_list)
+        lines_point_np_array = np.array(lines_point_list)
+        kmeans = KMeans(n_clusters=4, random_state=0).fit(lines_point_np_array)
+        point_0_list = []
+        point_1_list = []
+        point_2_list = []
+        point_3_list = []
+
+        for i in range(len(kmeans.labels_)):
+            if kmeans.labels_[i] == 0:
+                point_0_list.append(lines_point_list[i])
+            elif kmeans.labels_[i] == 1:
+                point_1_list.append(lines_point_list[i])
+            elif kmeans.labels_[i] == 2:
+                point_2_list.append(lines_point_list[i])
+            elif kmeans.labels_[i] == 3:
+                point_3_list.append(lines_point_list[i])
+        four_points_list = []
+
+        points_line_x_0_list = []
+        points_line_x_width_list = []
+        point_0 = np.mean(np.array(point_0_list), axis=0)
+        point_1 = np.mean(np.array(point_1_list), axis=0)
+        point_2 = np.mean(np.array(point_2_list), axis=0)
+        point_3 = np.mean(np.array(point_3_list), axis=0)
+
+        four_points_list.append([int(p) for p in point_0])
+        four_points_list.append([int(p) for p in point_1])
+        four_points_list.append([int(p) for p in point_2])
+        four_points_list.append([int(p) for p in point_3])
+        for point in four_points_list:
+            if point[0] == 0:
+                points_line_x_0_list.append(point)
+            elif point[0] == image.shape[1]:
+                points_line_x_width_list.append(point)
+
+        points_line_x_0_list.sort()
+        points_line_x_width_list.sort()
+
+        # points_line_x_0_list = []
+        # points_line_x_width_list = []
+        # for point in lines_point_list:
+        #     if point[0] == 0:
+        #         points_line_x_0_list.append(point)
+        #     elif point[0] == image.shape[1]:
+        #         points_line_x_width_list.append(point)
+        # points_line_x_0_list.sort()
+        # points_line_x_width_list.sort()
+        # # sorted(points_line_x_width_list,key=takeSecond)
+        # print(points_line_x_0_list)
+        # print(points_line_x_width_list)
+        # print('line1_:',[points_line_x_0_list[0],points_line_x_width_list[0]])
+        # print('line1_',[points_line_x_0_list[-1], points_line_x_width_list[-1]])
+        cv2.line(image, points_line_x_0_list[0],points_line_x_width_list[0], (0, 0, 255), thickness=1)
+        cv2.line(image, points_line_x_0_list[-1], points_line_x_width_list[-1], (0, 0, 255), thickness=1)
+        cv2.imshow('_max_length_2', image)
+        cv2.imwrite("_max_length_2.jpg", image)
+        cv2.waitKey()
+        # input()
+
+    @staticmethod
+    def get_convert_belt_angle(k1, k2):
+        convert_belt_angle = math.fabs(np.arctan((k1 - k2) / (float(1 + k1 * k2))) * 180 / np.pi)
+        return convert_belt_angle
+
+    @staticmethod
+    def _lines_up_or_up_down(max_len_line1, max_len_line2):
+        top_down_label_dict = {}
+        if max_len_line1[0] <= max_len_line2[1]:
+            top_down_label_dict = {1: list(max_len_line2), 2: list(max_len_line1)}
+        elif max_len_line1[0] > max_len_line2[1]:
+            top_down_label_dict = {1: list(max_len_line1), 2: list(max_len_line2)}
+
+        return top_down_label_dict
+
+    def get_bbox(self, point_1, point_2, image_shape, k1, k2):
+        result_order = []
+        line_y_0 = []
+        line_x_1920 = []
+        line_y_1080 = []
+        line_x_0 = []
+
+        positive = [(0, 0), (image_shape[0], image_shape[1])]
+        negative = [(0, image_shape[1]), (image_shape[0], 0)]
+        top_list = [(0, 0), (0, image_shape[1]), (image_shape[0], image_shape[1]), (image_shape[0], 0)]
+        point_list = point_1.copy()
+        point_list.extend(point_2)
+        xs = [x for x, _ in point_list]
+        ys = [y for _, y in point_list]
+
+        if xs.count(0) == 2:
+            self._list_remove(top_list, (0, 0))
+            self._list_remove(top_list, (0, image_shape[1]))
+        if xs.count(image_shape[0]) == 2:
+            self._list_remove(top_list, (image_shape[0], image_shape[1]))
+            self._list_remove(top_list, (image_shape[0], 0))
+
+        if ys.count(0) == 2:
+            self._list_remove(top_list, (0, 0))
+            self._list_remove(top_list, (image_shape[0], 0))
+
+        if ys.count(image_shape[1]) == 2:
+            self._list_remove(top_list, (0, image_shape[1]))
+            self._list_remove(top_list, (image_shape[0], image_shape[1]))
+
+        result = point_list
+        if len(top_list) == 0:
+            pass
+        elif len(top_list) == 2:
+            if k1 >= 0 and k2 >= 0:
+                result.extend(list(set(top_list) & set(positive)))
+            elif k1 < 0 and k2 < 0:
+                result.extend(list(set(top_list) & set(negative)))
+
+        elif len(top_list) == 4:
+            if k1 >= 0 and k2 >= 0:
+                result.extend(list(set(top_list) & set(positive)))
+            elif k1 < 0 and k2 < 0:
+                result.extend(list(set(top_list) & set(negative)))
+        for re in result:
+            if re[1] == 0:
+                line_y_0.append(re)
+            elif re[0] == 1920:
+                line_x_1920.append(re)
+            elif re[1] == 1080:
+                line_y_1080.append(re)
+            elif re[0] == 0:
+                line_x_0.append(re)
+        result_order.extend(sorted(line_y_0, key=lambda line: line[0]))
+        result_order.extend(sorted(line_x_1920, key=lambda line: line[1]))
+        result_order.extend(sorted(line_y_1080, key=lambda line: line[0], reverse=True))
+        result_order.extend(sorted(line_x_0, key=lambda line: line[1], reverse=True))
+
+        return result_order
+
+
+
+    def get_top_down_lines(self, image):
+        image = cv2.GaussianBlur(image, (3, 3), 0)
+        edges = cv2.Canny(image, 50, 150, apertureSize=3)
+        lines = cv2.HoughLines(edges, 1, np.pi / 180, 60)  # the last argument is an empirically chosen threshold
+        # max_len_lines = self._max_length(lines)
+        # max_len_lines = self._max_length_2(lines, image)
+        max_len_lines = self._max_length_3(lines, image)
+        return max_len_lines
+
+    def get_lines_point(self, max_len_lines, image):
+
+        k1, b1 = self._get_line(max_len_lines[0], image)
+        k2, b2 = self._get_line(max_len_lines[1], image)
+        point_1 = self._get_point((k1, b1), image.shape)
+        point_2 = self._get_point((k2, b2), image.shape)
+        return point_1, point_2, k1, k2
+
+    def get_result(self, image):
+        max_len_lines = self.get_top_down_lines(image)
+        point_1, point_2, k1, k2 = self.get_lines_point(max_len_lines, image)
+        result_order = self.get_bbox(point_1, point_2, image.shape, k1, k2)
+        convert_belt_angle = self.get_convert_belt_angle(k1, k2)
+        print("angle", convert_belt_angle)
+        top_down_label_dict = self._lines_up_or_up_down(max_len_lines[0], max_len_lines[1])
+        return max_len_lines, result_order, top_down_label_dict, convert_belt_angle
+
+    def draw_image(self, image):
+        image = cv2.GaussianBlur(image, (3, 3), 0)
+        edges = cv2.Canny(image, 50, 100, apertureSize=3)
+        lines = cv2.HoughLines(edges, 1, np.pi / 180, 60)
+        # print(lines)
+        # cv2.imshow('e', edges)
+        # cv2.waitKey(0)
+        max_len_lines = self._max_length(lines)
+        for lines in max_len_lines:
+            line = lines[0]
+            rho = line[0]  # the first element is the distance rho
+            theta = line[1]  # the second element is the angle theta
+            if (theta < (np.pi / 4.)) or (theta > (3. * np.pi / 4.0)):  # near-vertical line
+                pt1 = (int(rho / np.cos(theta)), 0)  # intersection with the first row
+                # intersection with the last row
+                pt2 = (int((rho - image.shape[0] * np.sin(theta)) / np.cos(theta)), image.shape[0])
+                image = cv2.line(image, pt1, pt2, (255, 0, 0), thickness=1)
+                print('line2_', [pt1, pt2])
+            else:  # near-horizontal line
+                pt1 = (0, int(rho / np.sin(theta)))  # intersection with the first column
+                # intersection with the last column
+                pt2 = (image.shape[1], int((rho - image.shape[1] * np.cos(theta)) / np.sin(theta)))
+                image = cv2.line(image, pt1, pt2, (255, 0, 0), thickness=1)
+                print('line2_', [pt1, pt2])
+
+        cv2.imshow('draw_image', image)
+        cv2.waitKey()
+
+        # cv2.imwrite("line_ALL.jpg", image)
+
+
+if __name__ == "__main__":
+    image = cv2.imread('huainanluohe_pidai_192.168.1.122_01_20210624145235484_202107082018_2600.png')
+    # cb = ConvertBeltInfo()
+    # cb.draw_image(image)
+    # cb.get_result(image)
+
+# if __name__ == "__main__":
+#     hl = hough_line(r'huainanluohe_pidai_192.168.1.122_01_2021062415321311_202107082019_3450.png')
+#     # hl.get_result()
+#     hl.img_view()
+
+    # bboxs=[[139, 494, 190, 552], [360, 401, 415, 465], [544, 307, 606, 377], [716, 233, 771, 301], [865, 159, 930, 229]
+    #     , [1030, 973, 1097, 1038], [1181, 856, 1248, 945], [1342, 737, 1416, 830], [1473, 598, 1555, 722], [1621, 466, 1696, 573], [1750, 343, 1829, 449]]

+ 279 - 0
code/data_manage/test_util/line2point/line2point.py

@@ -0,0 +1,279 @@
+"""
+# File       : line2point.py
+# Time       :21.7.8 10:13
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""
+import glob
+import json
+import os
+import shutil
+import cv2
+from scipy.spatial import Delaunay
+import numpy as np
+import pandas as pd
+
+def mkdir(new_folder):
+    if not os.path.exists(new_folder):
+        os.makedirs(new_folder)
+
+def path_format(path_str):
+    path = path_str.replace('\\', '/')
+    if str(path).endswith('/'):
+        # strip the trailing slash; '/'.join(path[0:-1]) would interleave slashes between characters
+        return str(path)[0:-1]
+    else:
+        return path
+
+def get_line(point_list):
+    k = (point_list[1][1] - point_list[0][1])/(point_list[1][0] - point_list[0][0])
+    b = point_list[0][1] - k*point_list[0][0]
+    return (k,b)
+
+
+def get_point(line, imgshape):
+    width = imgshape[0]
+    height = imgshape[1]
+    point_img = []
+    if 0 <= (line[0]*0 +line[1]) <= height:
+        point_img.extend([(0, int(line[0]*0 +line[1]))])
+
+    if 0 <= (line[0]*width +line[1]) <= height:
+        point_img.extend([(width, int(line[0] * width + line[1]))])
+
+    if 0 <= (0-line[1]+float('1e-8'))/(line[0]+float('1e-8')) <= width:
+        point_img.extend([(int((0-line[1]+float('1e-8'))/(line[0]+float('1e-8'))), 0)])
+
+    if 0 <= (height-line[1]+float('1e-8'))/(line[0]+float('1e-8')) <= width:
+        point_img.extend([(int((height-line[1]+float('1e-8'))/(line[0]+float('1e-8'))), height)])
+    return point_img
+
+def list_remove(lists, a):
+    if a in lists:
+        lists.remove(a)
+    else:
+        print('list:',lists)
+
+def get_bbox(point_1, point_2, imgshape, k1,k2):
+    result_order = []
+    line_y_0 = []
+    line_x_1920 = []
+    line_y_1080 = []
+    line_x_0 = []
+
+    positive = [(0,0), (imgshape[0], imgshape[1])]
+    negative = [(0, imgshape[1]),(imgshape[0], 0)]
+    top_list = [(0,0),(0, imgshape[1]),(imgshape[0], imgshape[1]), (imgshape[0], 0)]
+    point_list = point_1.copy()
+    point_list.extend(point_2)
+    xs = [x for x,_ in point_list]
+    ys = [y for _, y in point_list]
+
+    if xs.count(0) == 2:
+        list_remove(top_list, (0,0))
+        list_remove(top_list, (0, imgshape[1]))
+    if xs.count(imgshape[0]) == 2:
+        list_remove(top_list, (imgshape[0], imgshape[1]))
+        list_remove(top_list, (imgshape[0], 0))
+
+
+    if ys.count(0) == 2:
+        list_remove(top_list, (0,0))
+        list_remove(top_list, (imgshape[0], 0))
+
+    if ys.count(imgshape[1]) == 2:
+        list_remove(top_list, (0, imgshape[1]))
+        list_remove(top_list, (imgshape[0], imgshape[1]))
+
+    result = point_list
+    if len(top_list) == 0:
+        pass
+        # return result
+    elif len(top_list) == 2:
+        if k1 >= 0 and k2 >=0:
+            result.extend(list(set(top_list) & set(positive)))
+        elif k1 < 0 and k2 < 0:
+            result.extend(list(set(top_list) & set(negative)))
+
+    elif len(top_list) == 4:
+        if k1 >= 0 and k2 >= 0:
+            result.extend(list(set(top_list) & set(positive)))
+        elif k1 < 0 and k2 < 0:
+            result.extend(list(set(top_list) & set(negative)))
+    for re in result:
+        if re[1] == 0:
+           line_y_0.append(re)
+        elif re[0] == imgshape[0]:
+            line_x_1920.append(re)
+        elif re[1] == imgshape[1]:
+            line_y_1080.append(re)
+        elif re[0] == 0:
+            line_x_0.append(re)
+    result_order.extend(sorted(line_y_0, key=lambda line: line[0]))
+    result_order.extend(sorted(line_x_1920, key=lambda line: line[1]))
+    result_order.extend(sorted(line_y_1080, key=lambda line: line[0], reverse=True))
+    result_order.extend(sorted(line_x_0, key=lambda line: line[1], reverse=True))
+
+    return result_order
+
+def main(path):
+    result_list = []
+    for jso in glob.glob(path + '/*.json'):
+        line_1 = []
+        line_2 = []
+        with open(jso) as f:
+            jso_read = json.load(f)
+        line = jso_read['shapes'][0:2]
+        imgshape = (jso_read['imageWidth'], jso_read['imageHeight'])
+        for li in line:
+            if li['label'] == '1' and li['shape_type'] == 'line' and len(li['points']) == 2:
+                line_1 = li['points']
+            elif li['label'] == '2' and li['shape_type'] == 'line' and len(li['points']) == 2:
+                line_2 = li['points']
+            else:
+                print('Invalid annotation file.')
+                new_path = path_format(path) +'/redundant_data'
+                mkdir(new_path)
+                shutil.move(jso, new_path+'/'+ jso.split('/')[-1].split('\\')[-1])
+                jpg_name = jso.replace('.json', '.jpg')
+                shutil.move(jpg_name, new_path + '/' + jpg_name.split('/')[-1].split('\\')[-1])
+                break
+        if line_1 == [] or line_2 == []:
+            continue
+        if line_1[0] <= line_2[0] and line_1[1] >= line_2[1]:
+            print('Annotation file with mislabeled classes.')
+            new_path = path_format(path) + '/redundant_data'
+            mkdir(new_path)
+            shutil.move(jso, new_path + '/' + jso.split('/')[-1].split('\\')[-1])
+            jpg_name = jso.replace('.json','.jpg')
+            # print(jpg_name)
+            shutil.move(jpg_name, new_path + '/' + jpg_name.split('/')[-1].split('\\')[-1])
+            continue
+
+        line_1 = get_line(line_1)
+        line_2 = get_line(line_2)
+        point_img_1 = get_point(line_1, imgshape)
+        point_img_2 = get_point(line_2, imgshape)
+        try:
+            result = get_bbox(point_img_1, point_img_2, imgshape, line_1[0], line_2[0])
+        except Exception:
+            print('Problem with the annotated image, please check')
+            new_path = path_format(path) + '/redundant_data'
+            mkdir(new_path)
+            shutil.move(jso, new_path + '/' + jso.split('/')[-1].split('\\')[-1])
+            jpg_name = jso.replace('.json', '.jpg')
+            # print(jpg_name)
+            shutil.move(jpg_name, new_path + '/' + jpg_name.split('/')[-1].split('\\')[-1])
+            continue
+        jpg_file = jso.replace('.json', '.jpg')
+        img = cv2.imdecode(np.fromfile(jpg_file, dtype=np.uint8), -1)
+        points = []
+        for single_result in result:
+            points.append([single_result[0], single_result[1]])
+        result_np = np.array(points, np.int32)
+        img_background = np.zeros_like(img, dtype = np.uint8)
+        cv2.fillConvexPoly(img_background, result_np, (0, 255, 0))
+        # img_background = cv2.cvtColor(img_background, cv2.COLOR_GRAY2RGB)
+
+        mask_path = '/'.join(path_format(jso).split('/')[0:-1])+'/mask'
+        mkdir(mask_path)
+        file_name = path_format(jso).split('/')[-1].replace('.json', '.png')
+        mask_img = mask_path+'/'+file_name
+        print(img_background.shape)
+        print(file_name)
+        cv2.imencode('.jpg', img_background)[1].tofile(mask_img)
+        new_path = path_format(path) + '/img_data'
+        mkdir(new_path)
+        jpg_name = jso.replace('.json', '.jpg')
+        # print(jpg_name)
+        shutil.move(jpg_name, new_path + '/' + jpg_name.split('/')[-1].split('\\')[-1])
+        result_list.extend([[file_name,str(result).replace('[', '').replace(']', '').replace('"', '')]])
+    df = pd.DataFrame(result_list, columns=['filename', 'points'])
+    csv_path = path + '/results.csv'
+    df.to_csv(csv_path, encoding='utf8', index=False)
+
+path = r"C:\Users\Administrator\Desktop\7.12\7.12\淮南洛河电厂——皮带"
+main(path)
+
+# points = [(0, 774), (1539, 0), (1920, 20), (783, 1080), (0, 1080), (1920, 0)]
+#
+# pts = np.array(points).astype(np.int32)
+# points = pts
+# import matplotlib.pyplot as plt
+# from scipy.spatial import Delaunay
+#
+# def alpha_shape(points, alpha, only_outer=True):
+#     """
+#     Compute the alpha shape (concave hull) of a set of points.
+#     :param points: np.array of shape (n,2) points.
+#     :param alpha: alpha value.
+#     :param only_outer: boolean value to specify if we keep only the outer border
+#     or also inner edges.
+#     :return: set of (i,j) pairs representing edges of the alpha-shape. (i,j) are
+#     the indices in the points array.
+#     """
+#     assert points.shape[0] > 3, "Need at least four points"
+#
+#     def add_edge(edges, i, j):
+#         """
+#         Add an edge between the i-th and j-th points,
+#         if not in the list already
+#         """
+#         if (i, j) in edges or (j, i) in edges:
+#             # already added
+#             assert (j, i) in edges, "Can't go twice over same directed edge right?"
+#             if only_outer:
+#                 # if both neighboring triangles are in shape, it's not a boundary edge
+#                 edges.remove((j, i))
+#             return
+#         edges.add((i, j))
+#
+#     tri = Delaunay(points)
+#
+#     edges = set()
+#     # Loop over triangles:
+#     # ia, ib, ic = indices of corner points of the triangle
+#
+#     for ia, ib, ic in tri.vertices:
+#         pa = points[ia]
+#         pb = points[ib]
+#         pc = points[ic]
+#         # Computing radius of triangle circumcircle
+#         # www.mathalino.com/reviewer/derivation-of-formulas/derivation-of-formula-for-radius-of-circumcircle
+#         a = np.sqrt((pa[0] - pb[0]) ** 2 + (pa[1] - pb[1]) ** 2)
+#         b = np.sqrt((pb[0] - pc[0]) ** 2 + (pb[1] - pc[1]) ** 2)
+#         c = np.sqrt((pc[0] - pa[0]) ** 2 + (pc[1] - pa[1]) ** 2)
+#         s = (a + b + c) / 2.0
+#         area = np.sqrt(s * (s - a) * (s - b) * (s - c))
+#         circum_r = a * b * c / (4.0 * area)
+#
+#         if circum_r < alpha:
+#             add_edge(edges, ia, ib)
+#             add_edge(edges, ib, ic)
+#             add_edge(edges, ic, ia)
+#     return edges
+
+
+# Computing the alpha shape
+# Different alpha thresholds give different enclosing polygons; a badly chosen threshold
+# (e.g. too small) may fail to produce an enclosing polygon at all.
+# edges = alpha_shape(points, alpha=2000, only_outer=True)
+# print(edges)
+#
+# # Plotting the output
+# pag = np.zeros_like(image)
+# fig, ax = plt.subplots(figsize=(6,4))
+# ax.axis('equal')
+# plt.plot(points[:, 0], points[:, 1], '.',color='b')
+# for i, j in edges:
+#     print(points[[i, j], 0], points[[i, j], 1])
+#     ax.plot(points[[i, j], 0], points[[i, j], 1], color='red')
+#     pass
+# ax.invert_yaxis()
+# plt.show()
+

+ 260 - 0
code/data_manage/test_util/line2point/linePointAngleUtil.py

@@ -0,0 +1,260 @@
+"""
+# File       : linePointAngleUtil.py
+# Time       :21.7.16 14:02
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""
+import math
+import numpy as np
+import cv2
+
+from PIL import Image, ImageDraw, ImageFont
+
+
+def _calc_abc_from_line_2d(point1, point2):
+    x0, y0 = point1
+    x1, y1 = point2
+    a = y0 - y1
+    b = x1 - x0
+    c = x0 * y1 - x1 * y0
+    return a, b, c
+
+
+def _get_line_cross_point(line1, line2):
+    # x1y1x2y2
+    a0, b0, c0 = _calc_abc_from_line_2d(line1[0], line1[1])
+    a1, b1, c1 = _calc_abc_from_line_2d(line2[0], line2[1])
+    D = a0 * b1 - a1 * b0
+    if D == 0:
+        return None, None
+    x = (b0 * c1 - b1 * c0) / D
+    y = (a1 * c0 - a0 * c1) / D
+    return x, y
+
+
+def k_b_get_point(line, image_shape):
+    """
+    通过直线方程和图片大小确定该直线与图像边缘相交的点坐标
+    :param line: 直线方程的斜率和截距
+    :type line: (k,b)
+    :param image_shape: 图像宽高
+    :type image_shape: (h,w)
+    :return: 交点坐标信息
+    :rtype:[(x,y),(x,y)]
+    """
+    height = image_shape[0]
+    width = image_shape[1]
+    point_img = []
+    if 0 <= (line[0] * 0 + line[1]) <= height:
+        point_img.extend([(0, int(line[0] * 0 + line[1]))])
+
+    if 0 <= (line[0] * width + line[1]) <= height:
+        point_img.extend([(width, int(line[0] * width + line[1]))])
+
+    eps = 1e-8  # guards against division by zero for horizontal lines
+    if 0 <= (0 - line[1] + eps) / (line[0] + eps) <= width:
+        point_img.extend([(int((0 - line[1] + eps) / (line[0] + eps)), 0)])
+
+    if 0 <= (height - line[1] + eps) / (line[0] + eps) <= width:
+        point_img.extend([(int((height - line[1] + eps) / (line[0] + eps)), height)])
+    return point_img
+
+
+def lines_point_to_angle_bisector(line1_points, line2_points, image_shape):
+    """
+    通过两条直线的点坐标信息求出两条直线的角平分线与图像边缘的焦点
+    :param line1_points: 直线1的坐标点集合
+    :type line1_points: [(x,y),(x,y)]
+    :param line2_points: 直线2的坐标点集合
+    :type line2_points: [(x,y),(x,y)]
+    :param image_shape: 图像的高宽信息
+    :type image_shape: [w,h]
+    :return: 角平分线与图像边缘相交的点坐标
+    :rtype:[(x,y),(x,y)]
+    """
+    k1 = (line1_points[1][1] - line1_points[0][1]) / (line1_points[1][0] - line1_points[0][0])
+    k2 = (line2_points[1][1] - line2_points[0][1]) / (line2_points[1][0] - line2_points[0][0])
+    inv_k1 = np.arctan(k1)
+    inv_k2 = np.arctan(k2)
+    k = np.tan((inv_k1 + inv_k2) / 2)
+    x, y = _get_line_cross_point(line1_points, line2_points)
+    if x is None or y is None:
+        return None
+    b = y - k * x
+    line_points_list = k_b_get_point((k, b), image_shape)
+    return line_points_list
+
+
+def points_to_area(points):
+    """
+    根据顺时针/逆时针顺序的点坐标集合去计算闭合图形面积
+    :param points: 点坐标集合
+    :type points: [(x,y),(x,y),(x,y),.....]
+    :return: 面积
+    :rtype: float
+    """
+    s = 0
+    point_num = len(points)
+    if point_num >= 3:
+        s = points[0][1] * (points[point_num - 1][0] - points[1][0]) / 2  # i = 0 term (must be halved like the rest)
+        for i in range(1, point_num):
+            s += (points[i][1] * (points[i - 1][0] - points[(i + 1) % point_num][0])) / 2
+    return s
+
+
+def points_to_line_k_b(point_list):
+    """
+    根据点坐标求直线方程
+    :param point_list: 直线点坐标
+    :type point_list: [(x,y),(x,y)]
+    :return: 直线方程参数
+    :rtype: (k,b)
+    """
+    k = (point_list[1][1] - point_list[0][1]) / (point_list[1][0] - point_list[0][0])
+    b = point_list[0][1] - k * point_list[0][0]
+    result = (k, b)
+    return result
+
+
+def lines_k_to_convert_belt_angle(k1, k2):
+    """
+    根据两条直线的斜率求其夹角
+    :param k1: 斜率1
+    :type k1: float
+    :param k2: 斜率2
+    :type k2: float
+    :return: 夹角
+    :rtype: float
+    """
+    convert_belt_angle = math.fabs(np.arctan((k1 - k2) / (float(1 + k1 * k2))) * 180 / np.pi)
+    return convert_belt_angle
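+
+
+def _demo_line_helpers():
+    # Usage sketch for the helpers above, with hypothetical coordinates on a
+    # 1920x1080 image (image_shape is passed as (h, w), as k_b_get_point expects).
+    line_a = [(0, 782), (1510, 0)]
+    line_b = [(819, 1080), (1920, 140)]
+    k_a = points_to_line_k_b(line_a)[0]
+    k_b = points_to_line_k_b(line_b)[0]
+    print('included angle (deg):', lines_k_to_convert_belt_angle(k_a, k_b))
+    print('bisector border points:', lines_point_to_angle_bisector(line_a, line_b, (1080, 1920)))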
+
+
+def cv2ImgAddText(img, text, left, top, textColor=(255, 0, 0), textSize=20):
+    """
+    图片添加文本信息
+    :param img:图像
+    :type img:
+    :param text: 文本
+    :type text: str
+    :param left: x轴
+    :type left: int
+    :param top: y轴
+    :type top: int
+    :param textColor: RGB
+    :type textColor:(r, g, b)
+    :param textSize: 字体大小
+    :type textSize: int
+    :return:
+    :rtype:
+    """
+    # 判断是否为opencv图片类型
+    if (isinstance(img, np.ndarray)):
+        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+    draw = ImageDraw.Draw(img)
+    fontText = ImageFont.truetype('simsun.ttc', textSize, encoding="utf-8")
+    draw.text((left, top), text, textColor, font=fontText)
+    return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
+
+def _list_remove(lists, a):
+    if a in lists:
+        lists.remove(a)
+    else:
+        print('list:',lists)
+
+def get_bbox(point_1, point_2, width, high):
+    """Collect the border points of two lines plus any enclosed image corners, ordered clockwise along the image border."""
+
+    result_order = []
+    line_y_0 = []
+    line_x_width = []
+    line_y_high = []
+    line_x_0 = []
+
+    k1, _ = points_to_line_k_b(point_1)
+    k2, _ = points_to_line_k_b(point_2)
+    positive = [(0, 0), (width, high)]
+    negative = [(0, high), (width, 0)]
+    top_list = [(0, 0), (0, high), (width, high), (width, 0)]
+    point_list = point_1.copy()
+    point_list.extend(point_2)
+    xs = [x for x, _ in point_list]
+    ys = [y for _, y in point_list]
+
+    if xs.count(0) == 2:
+        _list_remove(top_list, (0, 0))
+        _list_remove(top_list, (0, high))
+    if xs.count(width) == 2:
+        _list_remove(top_list, (width, high))
+        _list_remove(top_list, (width, 0))
+
+    if ys.count(0) == 2:
+        _list_remove(top_list, (0, 0))
+        _list_remove(top_list, (width, 0))
+
+    if ys.count(high) == 2:
+        _list_remove(top_list, (0, high))
+        _list_remove(top_list, (width, high))
+
+    result = point_list
+    if len(top_list) == 0:
+        pass
+    elif len(top_list) == 2:
+        if k1 >= 0 and k2 >= 0:
+            result.extend(list(set(top_list) & set(positive)))
+        elif k1 < 0 and k2 < 0:
+            result.extend(list(set(top_list) & set(negative)))
+
+    elif len(top_list) == 4:
+        if k1 >= 0 and k2 >= 0:
+            result.extend(list(set(top_list) & set(positive)))
+        elif k1 < 0 and k2 < 0:
+            result.extend(list(set(top_list) & set(negative)))
+    for re in result:
+        if re[1] == 0:
+            line_y_0.append(re)
+        elif re[0] == width:
+            line_x_width.append(re)
+        elif re[1] == high:
+            line_y_high.append(re)
+        elif re[0] == 0:
+            line_x_0.append(re)
+    result_order.extend(sorted(line_y_0, key=lambda line: line[0]))
+    result_order.extend(sorted(line_x_width, key=lambda line: line[1]))
+    result_order.extend(sorted(line_y_high, key=lambda line: line[0], reverse=True))
+    result_order.extend(sorted(line_x_0, key=lambda line: line[1], reverse=True))
+
+    return result_order
+
+
+
+
+def images_iou(mask_image, line1_point, line2_point):
+    img_gray = cv2.cvtColor(mask_image, cv2.COLOR_RGB2GRAY)
+    h, w = img_gray.shape  # ndarray shape is (height, width)
+    result_order = get_bbox(line1_point, line2_point, w, h)
+    img_fun = np.zeros_like(img_gray, np.uint8)
+    img_fun = cv2.fillPoly(img_fun, [np.array(result_order)], 255)
+    # soft overlap of mask and polygon: elementwise product normalised by the sum
+    # (the 1e-5 guards against division by zero)
+    img_new = img_gray * img_fun / (img_gray + img_fun + 1e-5)
+
+    cv2.imshow('img', img_new)
+    cv2.waitKey()
+
+
+
+
+    # img_contour, contours, hierarchy = cv2.findContours(img_gray, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
+    # print(img_gray.shape)
+    # print('%s white pixels' % len(img_gray[img_gray == 255]))
+    # print('%s black pixels' % len(img_gray[img_gray == 0]))
+img = cv2.imread('huainanluohe_pidai_192.168.1.122_01_2021062415321311_202107082019_0850.png')
+images_iou(img, [(0,782), (1510,0)], [(819,1080), (1920,140)])
+
+
+
+# image = cv2.imread('huainanluohe_pidai_192.168.1.122_01_20210624145235484_202107082018_2600.png')
+# img = cv2ImgAddText(image, 'niushc你吃的uygasycgsga20', 0, 0)
+#
+# cv2.imshow("photo", img)
+#
+# cv2.waitKey(0)
+# cv2.destroyAllWindows()

+ 48 - 0
code/data_manage/test_util/merger_tfrecord/merger_tfrecord.py

@@ -0,0 +1,48 @@
+"""
+# File       : merger_tfrecord.py
+# Time       :21.7.6 14:14
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Merge two tfrecord files into one
+"""
+import os
+import tensorflow.compat.v1 as tf
+
+def tf_write_tf(path, writer):
+    # copy every serialized example from the record at `path` into `writer`
+    # (a for loop over the iterator replaces the original bare try/except,
+    # which silently swallowed all errors, not just StopIteration)
+    for record in tf.python_io.tf_record_iterator(path):
+        writer.write(record)
+
+def main():
+    while True:
+        path1 = input('1. Enter the path of the first of the two tfrecord files to merge: ')
+        if os.path.exists(path1) and path1.endswith('.record'):
+            break
+        else:
+            print('File does not exist or its extension is not .record, please try again')
+    while True:
+        path2 = input('2. Enter the path of the other tfrecord file: ')
+        if os.path.exists(path2) and path2.endswith('.record'):
+            break
+        else:
+            print('File does not exist or its extension is not .record, please try again')
+    while True:
+        path3 = input('3. Enter the name (with path) of the merged output file: ')
+        if path3.endswith('.record'):
+            break
+        else:
+            print('The file extension is not .record, please try again')
+
+    writer = tf.python_io.TFRecordWriter(path3)
+    tf_write_tf(path1, writer)
+    tf_write_tf(path2, writer)
+    writer.close()
+
+main()
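+
+
+# Generalization sketch (same tf.compat.v1 API as above): merge any number of
+# .record files in one pass.
+def merge_records(paths, out_path):
+    with tf.python_io.TFRecordWriter(out_path) as writer:
+        for p in paths:
+            for record in tf.python_io.tf_record_iterator(p):
+                writer.write(record)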
+
+

+ 65 - 0
code/data_manage/test_util/other2jpg/other2jpg.py

@@ -0,0 +1,65 @@
+"""
+# File       : other2jpg.py
+# Time       :21.6.23 10:29
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description: Convert images in other supported formats to .jpg (rewriting any paired VOC xml)
+"""
+
+import glob
+from PIL import Image
+import os
+from gen_data.utils import pathUtil
+import xml.etree.ElementTree as ET
+
+
+def convert(source, tmp_dir_path):
+    """
+
+    :param source:
+    :param target:
+    :return:
+    """
+    im = Image.open(source)
+    if im.mode not in ('RGB', 'L', 'CMYK'):
+        im = im.convert('RGB')  # JPEG cannot store alpha or palette modes
+    file_ = os.path.splitext(source)[0]
+    name_ = os.path.splitext(source)[0].split('/')[-1]
+    xmlfile = file_+'.xml'
+    if os.path.exists(xmlfile):
+        dom = ET.parse(xmlfile)
+        root = dom.getroot()
+        root.find('filename').text = name_+ '.jpg'
+        root.find('path').text = os.path.splitext(root.find('path').text)[0] + '.jpg'
+
+        # write the updated annotation to the target directory
+        xmlfile_path = pathUtil.path_format_join(tmp_dir_path, os.path.splitext(source)[0].split('/')[-1]) + '.xml'
+        dom.write(xmlfile_path, xml_declaration=True)
+
+
+    jpgpath = pathUtil.path_format_join(tmp_dir_path, os.path.splitext(source)[0].split('/')[-1]) + '.jpg'
+
+    im.save(jpgpath)
+
+def main():
+
+    supports = ['bmp', 'dib', 'gif', 'tif', 'tiff', 'jfif', 'jpe', 'jpeg', 'pbm', 'pgm', 'ppm',
+                'pnm', 'png', 'apng', 'pcx', 'ps', 'eps', 'jp2', 'j2k', 'jpc', 'jpf', 'jpx', 'j2c', 'ico', 'im',
+                'mpo', 'pdf', 'bw', 'rgb', 'rgba', 'sgi', 'tga', 'icb', 'vda', 'vst', 'webp', 'jpg']
+    flag = True
+    while flag:
+        path = input('Enter the directory of images to convert (enter 0 to quit): ')
+        dir_path = pathUtil.path_format(path)
+        if dir_path == '0':
+            exit(0)
+        elif os.path.exists(dir_path):
+            count = 0
+            tmp_dir_path = pathUtil.path_format_join(dir_path, '_temp_imgs')
+            pathUtil.mkdir_new(tmp_dir_path)
+            for support in supports:
+                for file_path in glob.glob(dir_path + '/*.' + support):
+                    convert(pathUtil.path_format(file_path), tmp_dir_path)
+                    count += 1
+            print('%d images converted; see them under the %s directory' % (count, tmp_dir_path))
+            exit(0)
+        else:
+            print('The path does not exist, please try again')

+ 39 - 0
code/data_manage/test_util/replace_color/parse_xml.py

@@ -0,0 +1,39 @@
+"""
+# File       : replace_color.py
+# Time       :21.6.30 15:01
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:根据name属性批量删除xml文件中的object
+"""
+
+import glob
+import os
+import xml.etree.ElementTree as ET
+
+
+
+def xml_parse(data_dir, save_path, lists):
+    """
+
+    :param data_dir: xml file path
+    :return: dataframe (filename, path, width, height, depth, class_name, x_min, y_min, x_max, y_max)
+    """
+    xml_list = []
+    for xml_file in glob.glob(data_dir + '/*.xml'):
+        print(xml_file)
+        tree = ET.parse(xml_file)
+        root = tree.getroot()
+        for member in root.findall('object'):
+            if member.find('name').text not in lists:
+                member.find('name').text = 'other'
+        tree.write(save_path+'/'+xml_file.split('/')[-1].split('\\')[-1], encoding='utf8')
+
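+
+# Note: the header says "remove", while xml_parse() above remaps unmatched names
+# to 'other' instead. A true removal variant would look like this (sketch):
+def xml_remove(data_dir, save_path, keep):
+    for xml_file in glob.glob(data_dir + '/*.xml'):
+        tree = ET.parse(xml_file)
+        root = tree.getroot()
+        for member in list(root.findall('object')):  # copy the list: we mutate root
+            if member.find('name').text not in keep:
+                root.remove(member)
+        tree.write(os.path.join(save_path, os.path.basename(xml_file)), encoding='utf8')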
+
+# def main():
+lists = ['tree']
+dir_path = r'C:\Users\Administrator\Desktop\新建文件夹\semantic_drone_dataset\training_set\gt\semantic\label_me_xml'
+save_path = r'C:\Users\Administrator\Desktop\新建文件夹\semantic_drone_dataset\training_set\gt\semantic\label_me_xml\new'
+if not os.path.exists(save_path):
+    os.makedirs(save_path)
+xml_parse(dir_path, save_path, lists)

+ 70 - 0
code/data_manage/test_util/replace_color/replace_color.py

@@ -0,0 +1,70 @@
+"""
+# File       : replace_color.py
+# Time       :21.6.30 15:01
+# Author     :FEANGYANG
+# version    :python 3.7
+# Contact    :1071082183@qq.com
+# Description:
+"""
+import os
+
+from PIL import Image
+import glob
+import numpy as np
+import time
+import shutil
+
+
+def replace_color(img, src_clr, dst_clr):
+    ''' Replace one exact color with another via matrix operations
+    @param	img:	image array
+    @param	src_clr:	color to replace (r,g,b)
+    @param	dst_clr:	target color	(r,g,b)
+    @return				image array after the replacement
+    '''
+
+
+    img_arr = np.asarray(img, dtype=np.double)
+
+    # split channels
+    r_img = img_arr[:, :, 0].copy()
+    g_img = img_arr[:, :, 1].copy()
+    b_img = img_arr[:, :, 2].copy()
+
+    # encode each pixel's RGB into a single integer
+    img = r_img * 256 * 256 + g_img * 256 + b_img
+    src_color = src_clr[0] * 256 * 256 + src_clr[1] * 256 + src_clr[2]
+
+    # index the matching pixels and replace their color
+    r_img[img == src_color] = dst_clr[0]
+    g_img[img == src_color] = dst_clr[1]
+    b_img[img == src_color] = dst_clr[2]
+
+    # merge channels
+    dst_img = np.array([r_img, g_img, b_img], dtype=np.uint8)
+    # back to image layout (h, w, c)
+    dst_img = dst_img.transpose(1, 2, 0)
+
+    return dst_img
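+
+
+# Equivalent single-mask sketch (same exact-match semantics, plain NumPy):
+def replace_color_mask(img, src_clr, dst_clr):
+    arr = np.asarray(img, dtype=np.uint8).copy()
+    mask = np.all(arr == np.array(src_clr, dtype=np.uint8), axis=-1)
+    arr[mask] = dst_clr
+    return arr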
+
+def main(new_path, old_path, color_list, new_color):
+    if not os.path.exists(new_path):
+        os.makedirs(new_path)
+    for file in glob.glob(old_path+'/*.png'):
+        print(file)
+        img = Image.open(file).convert('RGB')
+        dst_img = img.copy()
+        for color in color_list:
+            dst_img = replace_color(dst_img, color, new_color)
+        # dst_img = replace_color(dst_img, (51,51,0), (0,128,0))
+        res_img = Image.fromarray(dst_img)
+        res_img.save(new_path+'/'+file.split('\\')[-1].split('/')[-1])
+
+color_list = [(0,102,0), (112,103,87), (28,42,168), (255,0,0), (2,135,115), (112,150,146), (190,250,190), (119,11,32), (9,143,150), (102,51,0), (255,22,96),
+              (153,153,153), (190,153,153), (254,148,12), (254,228,12), (102,102,156), (70,70,70), (107,142,35), (0,50,89), (48,41,30), (130,76,0), (128,64,128), (28,42,168)]
+new_color = (0,0,0)
+new_path = r'C:\Users\Administrator\Desktop\新建文件夹\img2\new_label'
+old_path = r'C:\Users\Administrator\Desktop\新建文件夹\img2\label'
+main(new_path, old_path, color_list, new_color)

+ 51 - 0
code/object_detection.sh

@@ -0,0 +1,51 @@
+#!/bin/bash
+
+str=$1
+# str is the path argument passed in, e.g.: sudo ./object_detection.sh /data2/fengyang/sunwin/data/image/<project name>/
+# <directory of the preprocessed images; pointing at the parent directory of total_data is enough>
+
+if [ ! -n "$str" ]; then     # -n tests whether the string is non-empty
+  echo 'path is empty'
+  exit 0
+
+fi
+
+if [ ! -d $str ]; then   # -d is true when the path is a directory
+  echo "$str does not exist, please check the path..."
+  exit 0
+fi
+
+# Note: do not leave a trailing '/' on these paths
+#train_model_path='/data2/fengyang/sunwin/train_model'
+#train_code_path='/data2/fengyang/sunwin/code/models-master/research/object_detection'
+data_path='/data2/object_detection/data/image'  # <project name>/ directory holding total_data
+data_manage_code_path='/data2/object_detection/code/data_manage'  # path of the data_manage preprocessing code
+yolov5_code_path='/data2/object_detection/code/yolov5'
+
+if [[ $str == */ ]]; then project_path=${str%?};else project_path=$str;fi
+
+str_name=${project_path##*/}   # ${var##*/} strips everything up to the last '/', leaving just the project name
+
+project_class_img_path="${project_path}/class_img_dir"
+
+echo '333'
+sed -i "31c dir = $project_path" $data_manage_code_path/config/data_manage.conf  # sed -i 直接对内容进行修改
+
+sed -n '30,31p' $data_manage_code_path/config/data_manage.conf  # 取消默认输出,sed默认会输出所有文本内容,
+                                                                # 使用-n参数后只显示处理过的行,打印30到31行
+echo '111'
+
+cd $data_manage_code_path
+python run_gen.py 
+echo '222'
+
+
+echo 'Waiting 30 seconds... If you do not need the model-training step, close this script and kill its process now'
+sleep 30s
+
+
+cd $yolov5_code_path
+
+
+python train.py --data /data2/object_detection/data/image/$str_name/yolo/$str_name.yaml --weights weights/yolov5l6.pt --img 1280  --epochs 70 --name $str_name --batch-size 4
+

+ 301 - 0
code/yolov5/README.md

@@ -0,0 +1,301 @@
+<div align="center">
+<p>
+   <a align="left" href="https://ultralytics.com/yolov5" target="_blank">
+   <img width="850" src="https://github.com/ultralytics/yolov5/releases/download/v1.0/splash.jpg"></a>
+</p>
+<br>
+<div>
+   <a href="https://github.com/ultralytics/yolov5/actions"><img src="https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg" alt="CI CPU testing"></a>
+   <a href="https://zenodo.org/badge/latestdoi/264818686"><img src="https://zenodo.org/badge/264818686.svg" alt="YOLOv5 Citation"></a>
+   <a href="https://hub.docker.com/r/ultralytics/yolov5"><img src="https://img.shields.io/docker/pulls/ultralytics/yolov5?logo=docker" alt="Docker Pulls"></a>
+   <br>
+   <a href="https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
+   <a href="https://www.kaggle.com/ultralytics/yolov5"><img src="https://kaggle.com/static/images/open-in-kaggle.svg" alt="Open In Kaggle"></a>
+   <a href="https://join.slack.com/t/ultralytics/shared_invite/zt-w29ei8bp-jczz7QYUmDtgo6r6KcMIAg"><img src="https://img.shields.io/badge/Slack-Join_Forum-blue.svg?logo=slack" alt="Join Forum"></a>
+</div>
+
+<br>
+<p>
+YOLOv5 🚀 is a family of object detection architectures and models pretrained on the COCO dataset, and represents <a href="https://ultralytics.com">Ultralytics</a>
+ open-source research into future vision AI methods, incorporating lessons learned and best practices evolved over thousands of hours of research and development.
+</p>
+
+<div align="center">
+   <a href="https://github.com/ultralytics">
+   <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-social-github.png" width="2%"/>
+   </a>
+   <img width="2%" />
+   <a href="https://www.linkedin.com/company/ultralytics">
+   <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-social-linkedin.png" width="2%"/>
+   </a>
+   <img width="2%" />
+   <a href="https://twitter.com/ultralytics">
+   <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-social-twitter.png" width="2%"/>
+   </a>
+   <img width="2%" />
+   <a href="https://www.producthunt.com/@glenn_jocher">
+   <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-social-producthunt.png" width="2%"/>
+   </a>
+   <img width="2%" />
+   <a href="https://youtube.com/ultralytics">
+   <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-social-youtube.png" width="2%"/>
+   </a>
+   <img width="2%" />
+   <a href="https://www.facebook.com/ultralytics">
+   <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-social-facebook.png" width="2%"/>
+   </a>
+   <img width="2%" />
+   <a href="https://www.instagram.com/ultralytics/">
+   <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-social-instagram.png" width="2%"/>
+   </a>
+</div>
+
+<!--
+<a align="center" href="https://ultralytics.com/yolov5" target="_blank">
+<img width="800" src="https://github.com/ultralytics/yolov5/releases/download/v1.0/banner-api.png"></a>
+-->
+
+</div>
+
+## <div align="center">Documentation</div>
+
+See the [YOLOv5 Docs](https://docs.ultralytics.com) for full documentation on training, testing and deployment.
+
+## <div align="center">Quick Start Examples</div>
+
+<details open>
+<summary>Install</summary>
+
+Clone repo and install [requirements.txt](https://github.com/ultralytics/yolov5/blob/master/requirements.txt) in a
+[**Python>=3.7.0**](https://www.python.org/) environment, including
+[**PyTorch>=1.7**](https://pytorch.org/get-started/locally/).
+
+```bash
+git clone https://github.com/ultralytics/yolov5  # clone
+cd yolov5
+pip install -r requirements.txt  # install
+```
+
+</details>
+
+<details open>
+<summary>Inference</summary>
+
+Inference with YOLOv5 and [PyTorch Hub](https://github.com/ultralytics/yolov5/issues/36)
+. [Models](https://github.com/ultralytics/yolov5/tree/master/models) download automatically from the latest
+YOLOv5 [release](https://github.com/ultralytics/yolov5/releases).
+
+```python
+import torch
+
+# Model
+model = torch.hub.load('ultralytics/yolov5', 'yolov5s')  # or yolov5m, yolov5l, yolov5x, custom
+
+# Images
+img = 'https://ultralytics.com/images/zidane.jpg'  # or file, Path, PIL, OpenCV, numpy, list
+
+# Inference
+results = model(img)
+
+# Results
+results.print()  # or .show(), .save(), .crop(), .pandas(), etc.
+```
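+
+`results` can also be inspected programmatically; `results.pandas().xyxy[0]`
+returns the detections for the first image as a pandas DataFrame:
+
+```python
+df = results.pandas().xyxy[0]  # columns: xmin, ymin, xmax, ymax, confidence, class, name
+print(df.head())
+```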
+
+</details>
+
+<details>
+<summary>Inference with detect.py</summary>
+
+`detect.py` runs inference on a variety of sources, downloading [models](https://github.com/ultralytics/yolov5/tree/master/models) automatically from
+the latest YOLOv5 [release](https://github.com/ultralytics/yolov5/releases) and saving results to `runs/detect`.
+
+```bash
+python detect.py --source 0  # webcam
+                          img.jpg  # image
+                          vid.mp4  # video
+                          path/  # directory
+                          path/*.jpg  # glob
+                          'https://youtu.be/Zgi9g1ksQHc'  # YouTube
+                          'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream
+```
+
+</details>
+
+<details>
+<summary>Training</summary>
+
+The commands below reproduce YOLOv5 [COCO](https://github.com/ultralytics/yolov5/blob/master/data/scripts/get_coco.sh)
+results. [Models](https://github.com/ultralytics/yolov5/tree/master/models)
+and [datasets](https://github.com/ultralytics/yolov5/tree/master/data) download automatically from the latest
+YOLOv5 [release](https://github.com/ultralytics/yolov5/releases). Training times for YOLOv5n/s/m/l/x are
+1/2/4/6/8 days on a V100 GPU ([Multi-GPU](https://github.com/ultralytics/yolov5/issues/475) times faster). Use the
+largest `--batch-size` possible, or pass `--batch-size -1` for
+YOLOv5 [AutoBatch](https://github.com/ultralytics/yolov5/pull/5092). Batch sizes shown for V100-16GB.
+
+```bash
+python train.py --data coco.yaml --cfg yolov5n.yaml --weights '' --batch-size 128
+                                       yolov5s                                64
+                                       yolov5m                                40
+                                       yolov5l                                24
+                                       yolov5x                                16
+```
+
+<img width="800" src="https://user-images.githubusercontent.com/26833433/90222759-949d8800-ddc1-11ea-9fa1-1c97eed2b963.png">
+
+</details>
+
+<details open>
+<summary>Tutorials</summary>
+
+- [Train Custom Data](https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data)  🚀 RECOMMENDED
+- [Tips for Best Training Results](https://github.com/ultralytics/yolov5/wiki/Tips-for-Best-Training-Results)  ☘️
+  RECOMMENDED
+- [Weights & Biases Logging](https://github.com/ultralytics/yolov5/issues/1289)  🌟 NEW
+- [Roboflow for Datasets, Labeling, and Active Learning](https://github.com/ultralytics/yolov5/issues/4975)  🌟 NEW
+- [Multi-GPU Training](https://github.com/ultralytics/yolov5/issues/475)
+- [PyTorch Hub](https://github.com/ultralytics/yolov5/issues/36)  ⭐ NEW
+- [TFLite, ONNX, CoreML, TensorRT Export](https://github.com/ultralytics/yolov5/issues/251) 🚀
+- [Test-Time Augmentation (TTA)](https://github.com/ultralytics/yolov5/issues/303)
+- [Model Ensembling](https://github.com/ultralytics/yolov5/issues/318)
+- [Model Pruning/Sparsity](https://github.com/ultralytics/yolov5/issues/304)
+- [Hyperparameter Evolution](https://github.com/ultralytics/yolov5/issues/607)
+- [Transfer Learning with Frozen Layers](https://github.com/ultralytics/yolov5/issues/1314)  ⭐ NEW
+- [Architecture Summary](https://github.com/ultralytics/yolov5/issues/6998)  ⭐ NEW
+
+</details>
+
+## <div align="center">Environments</div>
+
+Get started in seconds with our verified environments. Click each icon below for details.
+
+<div align="center">
+    <a href="https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb">
+        <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-colab-small.png" width="15%"/>
+    </a>
+    <a href="https://www.kaggle.com/ultralytics/yolov5">
+        <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-kaggle-small.png" width="15%"/>
+    </a>
+    <a href="https://hub.docker.com/r/ultralytics/yolov5">
+        <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-docker-small.png" width="15%"/>
+    </a>
+    <a href="https://github.com/ultralytics/yolov5/wiki/AWS-Quickstart">
+        <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-aws-small.png" width="15%"/>
+    </a>
+    <a href="https://github.com/ultralytics/yolov5/wiki/GCP-Quickstart">
+        <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-gcp-small.png" width="15%"/>
+    </a>
+</div>
+
+## <div align="center">Integrations</div>
+
+<div align="center">
+    <a href="https://wandb.ai/site?utm_campaign=repo_yolo_readme">
+        <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-wb-long.png" width="49%"/>
+    </a>
+    <a href="https://roboflow.com/?ref=ultralytics">
+        <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-roboflow-long.png" width="49%"/>
+    </a>
+</div>
+
+|Weights and Biases|Roboflow ⭐ NEW|
+|:-:|:-:|
+|Automatically track and visualize all your YOLOv5 training runs in the cloud with [Weights & Biases](https://wandb.ai/site?utm_campaign=repo_yolo_readme)|Label and export your custom datasets directly to YOLOv5 for training with [Roboflow](https://roboflow.com/?ref=ultralytics) |
+
+<!-- ## <div align="center">Compete and Win</div>
+
+We are super excited about our first-ever Ultralytics YOLOv5 🚀 EXPORT Competition with **$10,000** in cash prizes!
+
+<p align="center">
+  <a href="https://github.com/ultralytics/yolov5/discussions/3213">
+  <img width="850" src="https://github.com/ultralytics/yolov5/releases/download/v1.0/banner-export-competition.png"></a>
+</p> -->
+
+## <div align="center">Why YOLOv5</div>
+
+<p align="left"><img width="800" src="https://user-images.githubusercontent.com/26833433/155040763-93c22a27-347c-4e3c-847a-8094621d3f4e.png"></p>
+<details>
+  <summary>YOLOv5-P5 640 Figure (click to expand)</summary>
+
+<p align="left"><img width="800" src="https://user-images.githubusercontent.com/26833433/155040757-ce0934a3-06a6-43dc-a979-2edbbd69ea0e.png"></p>
+</details>
+<details>
+  <summary>Figure Notes (click to expand)</summary>
+
+- **COCO AP val** denotes mAP@0.5:0.95 metric measured on the 5000-image [COCO val2017](http://cocodataset.org) dataset over various inference sizes from 256 to 1536.
+- **GPU Speed** measures average inference time per image on [COCO val2017](http://cocodataset.org) dataset using an [AWS p3.2xlarge](https://aws.amazon.com/ec2/instance-types/p3/) V100 instance at batch-size 32.
+- **EfficientDet** data from [google/automl](https://github.com/google/automl) at batch size 8.
+- **Reproduce** by `python val.py --task study --data coco.yaml --iou 0.7 --weights yolov5n6.pt yolov5s6.pt yolov5m6.pt yolov5l6.pt yolov5x6.pt`
+
+</details>
+
+### Pretrained Checkpoints
+
+|Model |size<br><sup>(pixels) |mAP<sup>val<br>0.5:0.95 |mAP<sup>val<br>0.5 |Speed<br><sup>CPU b1<br>(ms) |Speed<br><sup>V100 b1<br>(ms) |Speed<br><sup>V100 b32<br>(ms) |params<br><sup>(M) |FLOPs<br><sup>@640 (B)
+|---                    |---  |---    |---    |---    |---    |---    |---    |---
+|[YOLOv5n][assets]      |640  |28.0   |45.7   |**45** |**6.3**|**0.6**|**1.9**|**4.5**
+|[YOLOv5s][assets]      |640  |37.4   |56.8   |98     |6.4    |0.9    |7.2    |16.5
+|[YOLOv5m][assets]      |640  |45.4   |64.1   |224    |8.2    |1.7    |21.2   |49.0
+|[YOLOv5l][assets]      |640  |49.0   |67.3   |430    |10.1   |2.7    |46.5   |109.1
+|[YOLOv5x][assets]      |640  |50.7   |68.9   |766    |12.1   |4.8    |86.7   |205.7
+|                       |     |       |       |       |       |       |       |
+|[YOLOv5n6][assets]     |1280 |36.0   |54.4   |153    |8.1    |2.1    |3.2    |4.6
+|[YOLOv5s6][assets]     |1280 |44.8   |63.7   |385    |8.2    |3.6    |12.6   |16.8
+|[YOLOv5m6][assets]     |1280 |51.3   |69.3   |887    |11.1   |6.8    |35.7   |50.0
+|[YOLOv5l6][assets]     |1280 |53.7   |71.3   |1784   |15.8   |10.5   |76.8   |111.4
+|[YOLOv5x6][assets]<br>+ [TTA][TTA]|1280<br>1536 |55.0<br>**55.8** |72.7<br>**72.7** |3136<br>- |26.2<br>- |19.4<br>- |140.7<br>- |209.8<br>-
+
+<details>
+  <summary>Table Notes (click to expand)</summary>
+
+- All checkpoints are trained to 300 epochs with default settings. Nano and Small models use [hyp.scratch-low.yaml](https://github.com/ultralytics/yolov5/blob/master/data/hyps/hyp.scratch-low.yaml) hyps, all others use [hyp.scratch-high.yaml](https://github.com/ultralytics/yolov5/blob/master/data/hyps/hyp.scratch-high.yaml).
+- **mAP<sup>val</sup>** values are for single-model single-scale on [COCO val2017](http://cocodataset.org) dataset.<br>Reproduce by `python val.py --data coco.yaml --img 640 --conf 0.001 --iou 0.65`
+- **Speed** averaged over COCO val images using an [AWS p3.2xlarge](https://aws.amazon.com/ec2/instance-types/p3/) instance. NMS times (~1 ms/img) not included.<br>Reproduce by `python val.py --data coco.yaml --img 640 --task speed --batch 1`
+- **TTA** [Test Time Augmentation](https://github.com/ultralytics/yolov5/issues/303) includes reflection and scale augmentations.<br>Reproduce by `python val.py --data coco.yaml --img 1536 --iou 0.7 --augment`
+
+</details>
+
+## <div align="center">Contribute</div>
+
+We love your input! We want to make contributing to YOLOv5 as easy and transparent as possible. Please see our [Contributing Guide](CONTRIBUTING.md) to get started, and fill out the [YOLOv5 Survey](https://ultralytics.com/survey?utm_source=github&utm_medium=social&utm_campaign=Survey) to send us feedback on your experiences. Thank you to all our contributors!
+
+<a href="https://github.com/ultralytics/yolov5/graphs/contributors"><img src="https://opencollective.com/ultralytics/contributors.svg?width=990" /></a>
+
+## <div align="center">Contact</div>
+
+For YOLOv5 bugs and feature requests please visit [GitHub Issues](https://github.com/ultralytics/yolov5/issues). For business inquiries or
+professional support requests please visit [https://ultralytics.com/contact](https://ultralytics.com/contact).
+
+<br>
+
+<div align="center">
+    <a href="https://github.com/ultralytics">
+        <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-social-github.png" width="3%"/>
+    </a>
+    <img width="3%" />
+    <a href="https://www.linkedin.com/company/ultralytics">
+        <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-social-linkedin.png" width="3%"/>
+    </a>
+    <img width="3%" />
+    <a href="https://twitter.com/ultralytics">
+        <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-social-twitter.png" width="3%"/>
+    </a>
+    <img width="3%" />
+    <a href="https://www.producthunt.com/@glenn_jocher">
+    <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-social-producthunt.png" width="3%"/>
+    </a>
+    <img width="3%" />
+    <a href="https://youtube.com/ultralytics">
+        <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-social-youtube.png" width="3%"/>
+    </a>
+    <img width="3%" />
+    <a href="https://www.facebook.com/ultralytics">
+        <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-social-facebook.png" width="3%"/>
+    </a>
+    <img width="3%" />
+    <a href="https://www.instagram.com/ultralytics/">
+        <img src="https://github.com/ultralytics/yolov5/releases/download/v1.0/logo-social-instagram.png" width="3%"/>
+    </a>
+</div>
+
+[assets]: https://github.com/ultralytics/yolov5/releases
+[tta]: https://github.com/ultralytics/yolov5/issues/303

+ 260 - 0
code/yolov5/detect.py

@@ -0,0 +1,260 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+Run inference on images, videos, directories, streams, etc.
+
+Usage - sources:
+    $ python path/to/detect.py --weights yolov5s.pt --source 0              # webcam
+                                                             img.jpg        # image
+                                                             vid.mp4        # video
+                                                             path/          # directory
+                                                             path/*.jpg     # glob
+                                                             'https://youtu.be/Zgi9g1ksQHc'  # YouTube
+                                                             'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream
+
+Usage - formats:
+    $ python path/to/detect.py --weights yolov5s.pt                 # PyTorch
+                                         yolov5s.torchscript        # TorchScript
+                                         yolov5s.onnx               # ONNX Runtime or OpenCV DNN with --dnn
+                                         yolov5s.xml                # OpenVINO
+                                         yolov5s.engine             # TensorRT
+                                         yolov5s.mlmodel            # CoreML (macOS-only)
+                                         yolov5s_saved_model        # TensorFlow SavedModel
+                                         yolov5s.pb                 # TensorFlow GraphDef
+                                         yolov5s.tflite             # TensorFlow Lite
+                                         yolov5s_edgetpu.tflite     # TensorFlow Edge TPU
+"""
+
+import argparse
+import os
+import sys
+from pathlib import Path
+
+import torch
+import torch.backends.cudnn as cudnn
+
+FILE = Path(__file__).resolve()
+ROOT = FILE.parents[0]  # YOLOv5 root directory
+if str(ROOT) not in sys.path:
+    sys.path.append(str(ROOT))  # add ROOT to PATH
+ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
+
+from models.common import DetectMultiBackend
+from utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
+from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
+                           increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
+from utils.plots import Annotator, colors, save_one_box
+from utils.torch_utils import select_device, time_sync
+
+
+@torch.no_grad()
+def run(
+        weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
+        source=ROOT / 'data/images',  # file/dir/URL/glob, 0 for webcam
+        data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
+        imgsz=(640, 640),  # inference size (height, width)
+        conf_thres=0.25,  # confidence threshold
+        iou_thres=0.45,  # NMS IOU threshold
+        max_det=1000,  # maximum detections per image
+        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+        view_img=False,  # show results
+        save_txt=False,  # save results to *.txt
+        save_conf=False,  # save confidences in --save-txt labels
+        save_crop=False,  # save cropped prediction boxes
+        nosave=False,  # do not save images/videos
+        classes=None,  # filter by class: --class 0, or --class 0 2 3
+        agnostic_nms=False,  # class-agnostic NMS
+        augment=False,  # augmented inference
+        visualize=False,  # visualize features
+        update=False,  # update all models
+        project=ROOT / 'runs/detect',  # save results to project/name
+        name='exp',  # save results to project/name
+        exist_ok=False,  # existing project/name ok, do not increment
+        line_thickness=3,  # bounding box thickness (pixels)
+        hide_labels=False,  # hide labels
+        hide_conf=False,  # hide confidences
+        half=False,  # use FP16 half-precision inference
+        dnn=False,  # use OpenCV DNN for ONNX inference
+):
+    source = str(source)
+    save_img = not nosave and not source.endswith('.txt')  # save images unless --nosave or source is a .txt list
+    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)  # image or video file?
+    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))  # network URL?
+    webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)  # webcam / stream input?
+    if is_url and is_file:
+        source = check_file(source)  # download the file
+
+    # Directories
+    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run: creates runs/detect/exp{N}
+    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
+
+    # Load model
+    device = select_device(device)  # select device (GPU or CPU)
+    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)  # backend dispatcher: loads the matching weight format (PyTorch, TensorFlow, TensorRT, ...)
+    stride, names, pt = model.stride, model.names, model.pt
+    imgsz = check_img_size(imgsz, s=stride)  # check image size: must be a multiple of stride (32), adjusted otherwise
+
+    # Dataloader
+    if webcam:  # webcam / stream input
+        view_img = check_imshow()
+        cudnn.benchmark = True  # set True to speed up constant image size inference
+        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
+        bs = len(dataset)  # batch_size
+    else:
+        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)  # read all image/video paths under source
+        bs = 1  # batch_size
+    vid_path, vid_writer = [None] * bs, [None] * bs
+
+    # Run inference
+    model.warmup(imgsz=(1 if pt else bs, 3, *imgsz))  # warmup: one dummy forward pass
+    dt, seen = [0.0, 0.0, 0.0], 0
+    for path, im, im0s, vid_cap, s in dataset:  # path: image path, im: letterboxed image, im0s: original image, vid_cap: video capture or None, s: log string
+
+        # from collections import Counter
+        # count = Counter(im)
+        t1 = time_sync()
+        im = torch.from_numpy(im).to(device)
+        im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32 (half/full precision)
+        im /= 255  # 0 - 255 to 0.0 - 1.0 (normalize)
+        if len(im.shape) == 3:  # add batch dimension
+            im = im[None]  # expand for batch dim[1, 3, 640, 480]
+        t2 = time_sync()
+        dt[0] += t2 - t1
+
+        # Inference
+        visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False  # optional per-image feature-visualization dir under runs/detect/exp
+
+
+        pred = model(im, augment=augment, visualize=visualize)  # forward pass: [1, 18900, 85] => [1, 3*(80*80+40*40+20*20), x,y,w,h,conf,classes(80)]
+        t3 = time_sync()
+        dt[1] += t3 - t2
+
+        # NMS
+        pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)  # non-max suppression
+        dt[2] += time_sync() - t3
+
+        # Second-stage classifier (optional)
+        # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
+
+        # Process predictions, one image at a time
+        for i, det in enumerate(pred):  # per image
+            seen += 1
+            if webcam:  # batch_size >= 1 for webcam sources; take the i-th image of the batch
+                p, im0, frame = path[i], im0s[i].copy(), dataset.count
+                s += f'{i}: '
+            else:
+                p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)  # frame: index within a video stream
+
+            p = Path(p)  # to Path
+            save_path = str(save_dir / p.name)  # result image path, im.jpg
+            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # result label path, im.txt
+            s += '%gx%g ' % im.shape[2:]  # print string
+            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
+            imc = im0.copy() if save_crop else im0  # for save_crop
+            annotator = Annotator(im0, line_width=line_thickness, example=str(names))
+            if len(det):
+                # Rescale boxes from img_size to im0 size
+                det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()  # map boxes back onto the original image
+
+                # Print results
+                for c in det[:, -1].unique():  # log how many of each class were detected
+                    n = (det[:, -1] == c).sum()  # detections per class
+                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string
+
+                # Write results
+                for *xyxy, conf, cls in reversed(det):
+                    if save_txt:  # write the results to a txt file
+                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
+                        line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
+                        with open(txt_path + '.txt', 'a') as f:
+                            f.write(('%g ' * len(line)).rstrip() % line + '\n')
+
+                    if save_img or save_crop or view_img:  # Add bbox to the image
+                        c = int(cls)  # integer class
+                        label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
+                        annotator.box_label(xyxy, label, color=colors(c, True))
+                        if save_crop:
+                            save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)  # crop the predicted object out of the original image and save it
+
+            # Stream results
+            im0 = annotator.result()
+            if view_img:  # show the image
+                cv2.imshow(str(p), im0)
+                cv2.waitKey(1)  # 1 millisecond
+
+            # Save results (image with detections)
+            if save_img:  # save the image
+                if dataset.mode == 'image':
+                    cv2.imwrite(save_path, im0)
+                else:  # 'video' or 'stream'
+                    if vid_path[i] != save_path:  # new video
+                        vid_path[i] = save_path
+                        if isinstance(vid_writer[i], cv2.VideoWriter):
+                            vid_writer[i].release()  # release previous video writer
+                        if vid_cap:  # video
+                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
+                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                        else:  # stream
+                            fps, w, h = 30, im0.shape[1], im0.shape[0]
+                        save_path = str(Path(save_path).with_suffix('.mp4'))  # force *.mp4 suffix on results videos
+                        vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
+                    vid_writer[i].write(im0)
+
+        # Print time (inference-only)
+        LOGGER.info(f'{s}Done. ({t3 - t2:.3f}s)')
+
+    # Print results
+    t = tuple(x / seen * 1E3 for x in dt)  # per-image speeds
+    LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
+    if save_txt or save_img:  # when saving labels or images
+        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
+        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
+    if update:
+        strip_optimizer(weights)  # update model (to fix SourceChangeWarning)
+
+
+def parse_opt():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'runs/train/shjd2/weights/best.pt', help='model path(s)')
+    # parser.add_argument('--source', type=str, default='rtsp://admin:sunwin2019@192.168.20.240:554/h265/ch1/main/av_stream', help='file/dir/URL/glob, 0 for webcam')
+    parser.add_argument('--source', type=str, default='/data/fengyang/sunwin/code/yolov5/test1', help='file/dir/URL/glob, 0 for webcam')
+    # parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam')
+    # parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
+    parser.add_argument('--data', type=str, default='/data2/fengyang/sunwin/data/image/shanghai_jiading/yolo_txt/shanghai_jiading.yaml', help='(optional) dataset.yaml path')
+    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
+    parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')  # confidence threshold
+    parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')  # IoU threshold used by NMS
+    parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
+    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+    parser.add_argument('--view-img', action='store_true', help='show results')
+    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
+    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
+    parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
+    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
+    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
+    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')  # also suppress boxes across different classes (default False)
+    parser.add_argument('--augment', action='store_true', help='augmented inference')  # multi-scale / flipped inference
+    parser.add_argument('--visualize', action='store_true', help='visualize features')
+    parser.add_argument('--update', action='store_true', help='update all models')
+    parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
+    parser.add_argument('--name', default='exp', help='save results to project/name')
+    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
+    parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
+    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
+    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
+    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
+    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
+    opt = parser.parse_args()
+    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
+    print_args(vars(opt))
+    return opt
+
+
+def main(opt):
+    check_requirements(exclude=('tensorboard', 'thop'))
+    run(**vars(opt))
+
+
+if __name__ == "__main__":
+    opt = parse_opt()
+    main(opt)
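+
+# Programmatic use (sketch): parse_opt() simply forwards the parsed CLI options
+# to run(), so the same inference can be driven from Python, e.g.:
+#
+#     from detect import run
+#     run(weights=ROOT / 'yolov5s.pt', source=ROOT / 'data/images', conf_thres=0.4)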

+ 241 - 0
code/yolov5/detect1.py

@@ -0,0 +1,241 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+Run inference on images, videos, directories, streams, etc.
+
+Usage - sources:
+    $ python path/to/detect.py --weights yolov5s.pt --source 0              # webcam
+                                                             img.jpg        # image
+                                                             vid.mp4        # video
+                                                             path/          # directory
+                                                             path/*.jpg     # glob
+                                                             'https://youtu.be/Zgi9g1ksQHc'  # YouTube
+                                                             'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream
+
+Usage - formats:
+    $ python path/to/detect.py --weights yolov5s.pt                 # PyTorch
+                                         yolov5s.torchscript        # TorchScript
+                                         yolov5s.onnx               # ONNX Runtime or OpenCV DNN with --dnn
+                                         yolov5s.xml                # OpenVINO
+                                         yolov5s.engine             # TensorRT
+                                         yolov5s.mlmodel            # CoreML (macOS-only)
+                                         yolov5s_saved_model        # TensorFlow SavedModel
+                                         yolov5s.pb                 # TensorFlow GraphDef
+                                         yolov5s.tflite             # TensorFlow Lite
+                                         yolov5s_edgetpu.tflite     # TensorFlow Edge TPU
+"""
+
+import argparse
+import os
+import sys
+from pathlib import Path
+
+import torch
+import torch.backends.cudnn as cudnn
+
+FILE = Path(__file__).resolve()
+ROOT = FILE.parents[0]  # YOLOv5 root directory
+if str(ROOT) not in sys.path:
+    sys.path.append(str(ROOT))  # add ROOT to PATH
+ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
+
+from models.common import DetectMultiBackend
+from utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
+from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
+                           increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
+from utils.plots import Annotator, colors, save_one_box
+from utils.torch_utils import select_device, time_sync
+
+
+@torch.no_grad()
+def run(
+        weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
+        source=ROOT / 'data/images',  # file/dir/URL/glob, 0 for webcam
+        data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
+        imgsz=(640, 640),  # inference size (height, width)
+        conf_thres=0.25,  # confidence threshold
+        iou_thres=0.45,  # NMS IOU threshold
+        max_det=1000,  # maximum detections per image
+        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+        view_img=False,  # show results
+        save_txt=False,  # save results to *.txt
+        save_conf=False,  # save confidences in --save-txt labels
+        save_crop=False,  # save cropped prediction boxes
+        nosave=False,  # do not save images/videos
+        classes=None,  # filter by class: --class 0, or --class 0 2 3
+        agnostic_nms=False,  # class-agnostic NMS
+        augment=False,  # augmented inference
+        visualize=False,  # visualize features
+        update=False,  # update all models
+        project=ROOT / 'runs/detect',  # save results to project/name
+        name='exp',  # save results to project/name
+        exist_ok=False,  # existing project/name ok, do not increment
+        line_thickness=3,  # bounding box thickness (pixels)
+        hide_labels=False,  # hide labels
+        hide_conf=False,  # hide confidences
+        half=False,  # use FP16 half-precision inference
+        dnn=False,  # use OpenCV DNN for ONNX inference
+):
+    source = str(source)
+    save_img = not nosave and not source.endswith('.txt')  # save images unless --nosave or source is a .txt list
+    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)  # image or video file?
+    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))  # network URL?
+    webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)  # webcam / stream input?
+    if is_url and is_file:
+        source = check_file(source)  # download the file
+
+    # Directories
+    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run: creates runs/detect/exp{N}
+    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
+
+    # Load model
+    device = select_device(device)  # select device (GPU or CPU)
+    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)  # backend dispatcher: loads the matching weight format (PyTorch, TensorFlow, TensorRT, ...)
+    stride, names, pt = model.stride, model.names, model.pt
+    imgsz = check_img_size(imgsz, s=stride)  # check image size: must be a multiple of stride (32), adjusted otherwise
+
+    # Dataloader
+    if webcam:  # webcam / stream input
+        view_img = check_imshow()
+        cudnn.benchmark = True  # set True to speed up constant image size inference
+        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
+        bs = len(dataset)  # batch_size
+    else:
+        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)  # read all image/video paths under source
+        bs = 1  # batch_size
+    vid_path, vid_writer = [None] * bs, [None] * bs
+
+    # Run inference
+    model.warmup(imgsz=(1 if pt else bs, 3, *imgsz))  # warmup: one dummy forward pass
+    dt, seen = [0.0, 0.0, 0.0], 0
+    for path, im, im0s, vid_cap, s in dataset:  # path: image path, im: letterboxed image, im0s: original image, vid_cap: video capture or None, s: log string
+        # from collections import Counter
+        # count = Counter(im)
+        t1 = time_sync()
+        im = torch.from_numpy(im).to(device)
+        im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32 (half/full precision)
+        im /= 255  # 0 - 255 to 0.0 - 1.0 (normalize)
+        if len(im.shape) == 3:  # add batch dimension
+            im = im[None]  # expand for batch dim [1, 3, 640, 480]
+        t2 = time_sync()
+        dt[0] += t2 - t1
+
+        # Inference
+        visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False  # optional per-image feature-visualization dir under runs/detect/exp
+
+
+        pred = model(im, augment=augment, visualize=visualize)  # forward pass: [1, 18900, 85] => [1, 3*(80*80+40*40+20*20), x,y,w,h,conf,classes(80)]
+        t3 = time_sync()
+        dt[1] += t3 - t2
+
+        # NMS
+        pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)  # non-max suppression
+        dt[2] += time_sync() - t3
+
+        # Second-stage classifier (optional)
+        # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
+
+        # Process predictions, one image at a time
+        for i, det in enumerate(pred):  # per image
+            seen += 1
+            if webcam:  # batch_size >= 1 for webcam sources; take the i-th image of the batch
+                p, im0, frame = path[i], im0s[i].copy(), dataset.count
+                s += f'{i}: '
+            else:
+                p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)  # frame: index within a video stream
+
+            p = Path(p)  # to Path
+            save_path = str(save_dir / p.name)  # im.jpg 结果图片路径
+            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # result label path, im.txt
+            s += '%gx%g ' % im.shape[2:]  # print string
+            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
+            imc = im0.copy() if save_crop else im0  # for save_crop
+            annotator = Annotator(im0, line_width=line_thickness, example=str(names))
+            if len(det):
+                # Rescale boxes from img_size to im0 size
+                det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()  # map boxes back onto the original image
+
+                # Print results
+                for c in det[:, -1].unique():  # log how many of each class were detected
+                    n = (det[:, -1] == c).sum()  # detections per class
+                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string
+
+                # Write results
+                for *xyxy, conf, cls in reversed(det):
+                    if save_txt:  # write the results to a txt file
+                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
+                        line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
+                        with open(txt_path + '.txt', 'a') as f:
+                            f.write(('%g ' * len(line)).rstrip() % line + '\n')
+
+                    if save_img or save_crop or view_img:  # Add bbox to the image
+                        c = int(cls)  # integer class
+                        label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
+                        annotator.box_label(xyxy, label, color=colors(c, True))
+                        if save_crop:
+                            save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)  # crop the predicted object out of the original image and save it
+
+            # Stream results
+            im0 = annotator.result()
+            if view_img:  # display the image
+                cv2.imshow(str(p), im0)
+                cv2.waitKey(1)  # 1 millisecond
+
+            # Save results (image with detections)
+            if save_img:  # save the image
+                if dataset.mode == 'image':
+                    cv2.imwrite(save_path, im0)
+                else:  # 'video' or 'stream'
+                    if vid_path[i] != save_path:  # new video
+                        vid_path[i] = save_path
+                        if isinstance(vid_writer[i], cv2.VideoWriter):
+                            vid_writer[i].release()  # release previous video writer
+                        if vid_cap:  # video
+                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
+                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                        else:  # stream
+                            fps, w, h = 30, im0.shape[1], im0.shape[0]
+                        save_path = str(Path(save_path).with_suffix('.mp4'))  # force *.mp4 suffix on results videos
+                        vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
+                    vid_writer[i].write(im0)
+
+        # Print time (inference-only)
+        LOGGER.info(f'{s}Done. ({t3 - t2:.3f}s)')
+
+    # Print results
+    t = tuple(x / seen * 1E3 for x in dt)  # speeds per image (print detection speed)
+    LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
+    if save_txt or save_img:  # saving txt labels and/or images
+        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
+        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
+    if update:
+        strip_optimizer(weights)  # update model (to fix SourceChangeWarning)
+
+
+def parse_opt():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'runs/train/shjd2/weights/best.pt', help='model path(s)')
+    parser.add_argument('--source', type=str, default='/data/fengyang/sunwin/code/yolov5/test1', help='file/dir/URL/glob, 0 for webcam')
+    # parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob, 0 for webcam')
+    # parser.add_argument('--data', type=str, default='/data2/fengyang/sunwin/data/image/shanghai_jiading/yolo_txt/shanghai_jiading.yaml', help='(optional) dataset.yaml path')
+    parser.add_argument('--data', type=str, default='/data/fengyang/sunwin/data/helmet_fall_phone_delete_work/helmet_fall_phone.yaml', help='(optional) dataset.yaml path')
+    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
+    parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
+    parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
+    parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
+    parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+    parser.add_argument('--view-img', action='store_true', help='show results')
+    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
+    parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
+    parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
+    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
+    parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
+    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')  # if set, NMS also suppresses overlapping boxes across different classes; default False
+    parser.add_argument('--augment', action='store_true', help='augmented inference')  # multi-scale and flip augmentation at inference time
+    parser.add_argument('--visualize', action='store_true', help='visualize features')
+    parser.add_argument('--update', action='store_true', help='update all models')
+    parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
+    parser.add_argument('--name', default='exp', help='save results to project/name')
+    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
+    parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
+    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
+    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
+    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
+    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
+    opt = parser.parse_args()
+    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
+    print_args(vars(opt))
+    return opt
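+# NOTE (illustrative usage sketch, not part of the original code): every flag above maps to a
+# run() keyword argument, so a typical CLI call might look like
+#   python detect.py --weights best.pt --source path/to/images --conf-thres 0.4 --save-txt --save-conf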
+
+
+def main(opt):
+    check_requirements(exclude=('tensorboard', 'thop'))
+    run(**vars(opt))
+
+
+if __name__ == "__main__":
+    opt = parse_opt()
+    main(opt)

+ 596 - 0
code/yolov5/export.py

@@ -0,0 +1,596 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+Export a YOLOv5 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit
+
+Format                      | `export.py --include`         | Model
+---                         | ---                           | ---
+PyTorch                     | -                             | yolov5s.pt
+TorchScript                 | `torchscript`                 | yolov5s.torchscript
+ONNX                        | `onnx`                        | yolov5s.onnx
+OpenVINO                    | `openvino`                    | yolov5s_openvino_model/
+TensorRT                    | `engine`                      | yolov5s.engine
+CoreML                      | `coreml`                      | yolov5s.mlmodel
+TensorFlow SavedModel       | `saved_model`                 | yolov5s_saved_model/
+TensorFlow GraphDef         | `pb`                          | yolov5s.pb
+TensorFlow Lite             | `tflite`                      | yolov5s.tflite
+TensorFlow Edge TPU         | `edgetpu`                     | yolov5s_edgetpu.tflite
+TensorFlow.js               | `tfjs`                        | yolov5s_web_model/
+
+Requirements:
+    $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu  # CPU
+    $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow  # GPU
+
+Usage:
+    $ python path/to/export.py --weights yolov5s.pt --include torchscript onnx openvino engine coreml tflite ...
+
+Inference:
+    $ python path/to/detect.py --weights yolov5s.pt                 # PyTorch
+                                         yolov5s.torchscript        # TorchScript
+                                         yolov5s.onnx               # ONNX Runtime or OpenCV DNN with --dnn
+                                         yolov5s.xml                # OpenVINO
+                                         yolov5s.engine             # TensorRT
+                                         yolov5s.mlmodel            # CoreML (macOS-only)
+                                         yolov5s_saved_model        # TensorFlow SavedModel
+                                         yolov5s.pb                 # TensorFlow GraphDef
+                                         yolov5s.tflite             # TensorFlow Lite
+                                         yolov5s_edgetpu.tflite     # TensorFlow Edge TPU
+
+TensorFlow.js:
+    $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
+    $ npm install
+    $ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model
+    $ npm start
+"""
+
+import argparse
+import json
+import os
+import platform
+import subprocess
+import sys
+import time
+import warnings
+from pathlib import Path
+
+import pandas as pd
+import torch
+from torch.utils.mobile_optimizer import optimize_for_mobile
+
+FILE = Path(__file__).resolve()
+ROOT = FILE.parents[0]  # YOLOv5 root directory
+if str(ROOT) not in sys.path:
+    sys.path.append(str(ROOT))  # add ROOT to PATH
+if platform.system() != 'Windows':
+    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
+
+from models.experimental import attempt_load
+from models.yolo import Detect
+from utils.datasets import LoadImages
+from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, colorstr,
+                           file_size, print_args, url2file)
+from utils.torch_utils import select_device
+
+
+def export_formats():
+    # YOLOv5 export formats
+    x = [
+        ['PyTorch', '-', '.pt', True],
+        ['TorchScript', 'torchscript', '.torchscript', True],
+        ['ONNX', 'onnx', '.onnx', True],
+        ['OpenVINO', 'openvino', '_openvino_model', False],
+        ['TensorRT', 'engine', '.engine', True],
+        ['CoreML', 'coreml', '.mlmodel', False],
+        ['TensorFlow SavedModel', 'saved_model', '_saved_model', True],
+        ['TensorFlow GraphDef', 'pb', '.pb', True],
+        ['TensorFlow Lite', 'tflite', '.tflite', False],
+        ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False],
+        ['TensorFlow.js', 'tfjs', '_web_model', False],]
+    return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'GPU'])
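+# NOTE (illustrative, not part of the original code): the 'Argument' column above, minus PyTorch's '-',
+# is what run() later validates --include against, i.e.
+#   export_formats()['Argument'].tolist()[1:]
+#   -> ['torchscript', 'onnx', 'openvino', 'engine', 'coreml', 'saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs']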
+
+
+def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
+    # YOLOv5 TorchScript model export
+    try:
+        LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
+        f = file.with_suffix('.torchscript')
+
+        ts = torch.jit.trace(model, im, strict=False)
+        d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
+        extra_files = {'config.txt': json.dumps(d)}  # torch._C.ExtraFilesMap()
+        if optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
+            optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
+        else:
+            ts.save(str(f), _extra_files=extra_files)
+
+        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
+        return f
+    except Exception as e:
+        LOGGER.info(f'{prefix} export failure: {e}')
+
+
+def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
+    # YOLOv5 ONNX export
+    try:
+        check_requirements(('onnx',))
+        import onnx
+
+        LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
+        f = file.with_suffix('.onnx')
+
+        torch.onnx.export(
+            model,
+            im,
+            f,
+            verbose=False,
+            opset_version=opset,
+            training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
+            do_constant_folding=not train,
+            input_names=['images'],
+            output_names=['output'],
+            dynamic_axes={
+                'images': {
+                    0: 'batch',
+                    2: 'height',
+                    3: 'width'},  # shape(1,3,640,640)
+                'output': {
+                    0: 'batch',
+                    1: 'anchors'}  # shape(1,25200,85)
+            } if dynamic else None)
+
+        # Checks
+        model_onnx = onnx.load(f)  # load onnx model
+        onnx.checker.check_model(model_onnx)  # check onnx model
+
+        # Metadata
+        d = {'stride': int(max(model.stride)), 'names': model.names}
+        for k, v in d.items():
+            meta = model_onnx.metadata_props.add()
+            meta.key, meta.value = k, str(v)
+        onnx.save(model_onnx, f)
+
+        # Simplify
+        if simplify:
+            try:
+                check_requirements(('onnx-simplifier',))
+                import onnxsim
+
+                LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
+                model_onnx, check = onnxsim.simplify(model_onnx,
+                                                     dynamic_input_shape=dynamic,
+                                                     input_shapes={'images': list(im.shape)} if dynamic else None)
+                assert check, 'assert check failed'
+                onnx.save(model_onnx, f)
+            except Exception as e:
+                LOGGER.info(f'{prefix} simplifier failure: {e}')
+        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
+        return f
+    except Exception as e:
+        LOGGER.info(f'{prefix} export failure: {e}')
+
+
+def export_openvino(model, im, file, prefix=colorstr('OpenVINO:')):
+    # YOLOv5 OpenVINO export
+    try:
+        check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
+        import openvino.inference_engine as ie
+
+        LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
+        f = str(file).replace('.pt', '_openvino_model' + os.sep)
+
+        cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f}"
+        subprocess.check_output(cmd, shell=True)
+
+        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
+        return f
+    except Exception as e:
+        LOGGER.info(f'\n{prefix} export failure: {e}')
+
+
+def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
+    # YOLOv5 CoreML export
+    try:
+        check_requirements(('coremltools',))
+        import coremltools as ct
+
+        LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
+        f = file.with_suffix('.mlmodel')
+
+        ts = torch.jit.trace(model, im, strict=False)  # TorchScript model
+        ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
+        bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)
+        if bits < 32:
+            if platform.system() == 'Darwin':  # quantization only supported on macOS
+                with warnings.catch_warnings():
+                    warnings.filterwarnings("ignore", category=DeprecationWarning)  # suppress numpy==1.20 float warning
+                    ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
+            else:
+                print(f'{prefix} quantization only supported on macOS, skipping...')
+        ct_model.save(f)
+
+        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
+        return ct_model, f
+    except Exception as e:
+        LOGGER.info(f'\n{prefix} export failure: {e}')
+        return None, None
+
+
+def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
+    # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
+    try:
+        assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
+        check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',))
+        import tensorrt as trt
+
+        if trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
+            grid = model.model[-1].anchor_grid
+            model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
+            export_onnx(model, im, file, 12, train, False, simplify)  # opset 12
+            model.model[-1].anchor_grid = grid
+        else:  # TensorRT >= 8
+            check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
+            export_onnx(model, im, file, 13, train, False, simplify)  # opset 13
+        onnx = file.with_suffix('.onnx')
+
+        LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
+        assert onnx.exists(), f'failed to export ONNX file: {onnx}'
+        f = file.with_suffix('.engine')  # TensorRT engine file
+        logger = trt.Logger(trt.Logger.INFO)
+        if verbose:
+            logger.min_severity = trt.Logger.Severity.VERBOSE
+
+        builder = trt.Builder(logger)
+        config = builder.create_builder_config()
+        config.max_workspace_size = workspace * 1 << 30
+        # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation notice
+
+        flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
+        network = builder.create_network(flag)
+        parser = trt.OnnxParser(network, logger)
+        if not parser.parse_from_file(str(onnx)):
+            raise RuntimeError(f'failed to load ONNX file: {onnx}')
+
+        inputs = [network.get_input(i) for i in range(network.num_inputs)]
+        outputs = [network.get_output(i) for i in range(network.num_outputs)]
+        LOGGER.info(f'{prefix} Network Description:')
+        for inp in inputs:
+            LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
+        for out in outputs:
+            LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
+
+        LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 else 32} engine in {f}')
+        if builder.platform_has_fast_fp16:
+            config.set_flag(trt.BuilderFlag.FP16)
+        with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
+            t.write(engine.serialize())
+        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
+        return f
+    except Exception as e:
+        LOGGER.info(f'\n{prefix} export failure: {e}')
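+# NOTE (illustrative usage sketch, not part of the original code): TensorRT export must run on a GPU
+# per the assert above, e.g.
+#   python export.py --weights yolov5s.pt --include engine --device 0 --half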
+
+
+def export_saved_model(model,
+                       im,
+                       file,
+                       dynamic,
+                       tf_nms=False,
+                       agnostic_nms=False,
+                       topk_per_class=100,
+                       topk_all=100,
+                       iou_thres=0.45,
+                       conf_thres=0.25,
+                       keras=False,
+                       prefix=colorstr('TensorFlow SavedModel:')):
+    # YOLOv5 TensorFlow SavedModel export
+    try:
+        import tensorflow as tf
+        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
+
+        from models.tf import TFDetect, TFModel
+
+        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
+        f = str(file).replace('.pt', '_saved_model')
+        batch_size, ch, *imgsz = list(im.shape)  # BCHW
+
+        tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
+        im = tf.zeros((batch_size, *imgsz, ch))  # BHWC order for TensorFlow
+        _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
+        inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
+        outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
+        keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
+        keras_model.trainable = False
+        keras_model.summary()
+        if keras:
+            keras_model.save(f, save_format='tf')
+        else:
+            spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
+            m = tf.function(lambda x: keras_model(x))  # full model
+            m = m.get_concrete_function(spec)
+            frozen_func = convert_variables_to_constants_v2(m)
+            tfm = tf.Module()
+            tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x)[0], [spec])
+            tfm.__call__(im)
+            tf.saved_model.save(tfm,
+                                f,
+                                options=tf.saved_model.SaveOptions(experimental_custom_gradients=False)
+                                if check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions())
+        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
+        return keras_model, f
+    except Exception as e:
+        LOGGER.info(f'\n{prefix} export failure: {e}')
+        return None, None
+
+
+def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')):
+    # YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
+    try:
+        import tensorflow as tf
+        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
+
+        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
+        f = file.with_suffix('.pb')
+
+        m = tf.function(lambda x: keras_model(x))  # full model
+        m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
+        frozen_func = convert_variables_to_constants_v2(m)
+        frozen_func.graph.as_graph_def()
+        tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
+
+        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
+        return f
+    except Exception as e:
+        LOGGER.info(f'\n{prefix} export failure: {e}')
+
+
+def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
+    # YOLOv5 TensorFlow Lite export
+    try:
+        import tensorflow as tf
+
+        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
+        batch_size, ch, *imgsz = list(im.shape)  # BCHW
+        f = str(file).replace('.pt', '-fp16.tflite')
+
+        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+        converter.target_spec.supported_types = [tf.float16]
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        if int8:
+            from models.tf import representative_dataset_gen
+            dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False)  # representative data
+            converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
+            converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+            converter.target_spec.supported_types = []
+            converter.inference_input_type = tf.uint8  # or tf.int8
+            converter.inference_output_type = tf.uint8  # or tf.int8
+            converter.experimental_new_quantizer = True
+            f = str(file).replace('.pt', '-int8.tflite')
+        if nms or agnostic_nms:
+            converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
+
+        tflite_model = converter.convert()
+        open(f, "wb").write(tflite_model)
+        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
+        return f
+    except Exception as e:
+        LOGGER.info(f'\n{prefix} export failure: {e}')
+
+
+def export_edgetpu(keras_model, im, file, prefix=colorstr('Edge TPU:')):
+    # YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
+    try:
+        cmd = 'edgetpu_compiler --version'
+        help_url = 'https://coral.ai/docs/edgetpu/compiler/'
+        assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
+        if subprocess.run(cmd + ' >/dev/null', shell=True).returncode != 0:
+            LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
+            sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0  # sudo installed on system
+            for c in (
+                    'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
+                    'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
+                    'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
+                subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
+        ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
+
+        LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
+        f = str(file).replace('.pt', '-int8_edgetpu.tflite')  # Edge TPU model
+        f_tfl = str(file).replace('.pt', '-int8.tflite')  # TFLite model
+
+        cmd = f"edgetpu_compiler -s -o {file.parent} {f_tfl}"
+        subprocess.run(cmd, shell=True, check=True)
+
+        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
+        return f
+    except Exception as e:
+        LOGGER.info(f'\n{prefix} export failure: {e}')
+
+
+def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
+    # YOLOv5 TensorFlow.js export
+    try:
+        check_requirements(('tensorflowjs',))
+        import re
+
+        import tensorflowjs as tfjs
+
+        LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
+        f = str(file).replace('.pt', '_web_model')  # js dir
+        f_pb = file.with_suffix('.pb')  # *.pb path
+        f_json = f + '/model.json'  # *.json path
+
+        cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
+              f'--output_node_names="Identity,Identity_1,Identity_2,Identity_3" {f_pb} {f}'
+        subprocess.run(cmd, shell=True)
+
+        with open(f_json) as j:
+            json = j.read()
+        with open(f_json, 'w') as j:  # sort JSON Identity_* in ascending order
+            subst = re.sub(
+                r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
+                r'"Identity.?.?": {"name": "Identity.?.?"}, '
+                r'"Identity.?.?": {"name": "Identity.?.?"}, '
+                r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, '
+                r'"Identity_1": {"name": "Identity_1"}, '
+                r'"Identity_2": {"name": "Identity_2"}, '
+                r'"Identity_3": {"name": "Identity_3"}}}', json)
+            j.write(subst)
+
+        LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
+        return f
+    except Exception as e:
+        LOGGER.info(f'\n{prefix} export failure: {e}')
+
+
+@torch.no_grad()
+def run(
+        data=ROOT / 'data/coco128.yaml',  # 'dataset.yaml path'
+        weights=ROOT / 'yolov5s.pt',  # weights path
+        imgsz=(640, 640),  # image (height, width)
+        batch_size=1,  # batch size
+        device='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+        include=('torchscript', 'onnx'),  # include formats
+        half=False,  # FP16 half-precision export
+        inplace=False,  # set YOLOv5 Detect() inplace=True
+        train=False,  # model.train() mode
+        optimize=False,  # TorchScript: optimize for mobile
+        int8=False,  # CoreML/TF INT8 quantization
+        dynamic=False,  # ONNX/TF: dynamic axes
+        simplify=False,  # ONNX: simplify model
+        opset=12,  # ONNX: opset version
+        verbose=False,  # TensorRT: verbose log
+        workspace=4,  # TensorRT: workspace size (GB)
+        nms=False,  # TF: add NMS to model
+        agnostic_nms=False,  # TF: add agnostic NMS to model
+        topk_per_class=100,  # TF.js NMS: topk per class to keep
+        topk_all=100,  # TF.js NMS: topk for all classes to keep
+        iou_thres=0.45,  # TF.js NMS: IoU threshold
+        conf_thres=0.25,  # TF.js NMS: confidence threshold
+):
+    t = time.time()
+    include = [x.lower() for x in include]  # to lowercase
+    formats = tuple(export_formats()['Argument'][1:])  # --include arguments
+    flags = [x in include for x in formats]
+    assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {formats}'
+    jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags  # export booleans
+    file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)  # PyTorch weights
+
+    # Load PyTorch model
+    device = select_device(device)
+    if half:
+        assert device.type != 'cpu' or coreml, '--half only compatible with GPU export, i.e. use --device 0'
+    model = attempt_load(weights, map_location=device, inplace=True, fuse=True)  # load FP32 model
+    nc, names = model.nc, model.names  # number of classes, class names
+
+    # Checks
+    imgsz *= 2 if len(imgsz) == 1 else 1  # expand
+    assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}'
+
+    # Input
+    gs = int(max(model.stride))  # grid size (max stride)
+    imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiples
+    im = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection
+
+    # Update model
+    if half and not coreml:
+        im, model = im.half(), model.half()  # to FP16
+    model.train() if train else model.eval()  # training mode = no Detect() layer grid construction
+    for k, m in model.named_modules():
+        if isinstance(m, Detect):
+            m.inplace = inplace
+            m.onnx_dynamic = dynamic
+            m.export = True
+
+    for _ in range(2):
+        y = model(im)  # dry runs
+    shape = tuple(y[0].shape)  # model output shape
+    LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")
+
+    # Exports
+    f = [''] * 10  # exported filenames
+    warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)  # suppress TracerWarning
+    if jit:
+        f[0] = export_torchscript(model, im, file, optimize)
+    if engine:  # TensorRT required before ONNX
+        f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose)
+    if onnx or xml:  # OpenVINO requires ONNX
+        f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
+    if xml:  # OpenVINO
+        f[3] = export_openvino(model, im, file)
+    if coreml:
+        _, f[4] = export_coreml(model, im, file, int8, half)
+
+    # TensorFlow Exports
+    if any((saved_model, pb, tflite, edgetpu, tfjs)):
+        if int8 or edgetpu:  # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
+            check_requirements(('flatbuffers==1.12',))  # required before `import tensorflow`
+        assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.'
+        model, f[5] = export_saved_model(model.cpu(),
+                                         im,
+                                         file,
+                                         dynamic,
+                                         tf_nms=nms or agnostic_nms or tfjs,
+                                         agnostic_nms=agnostic_nms or tfjs,
+                                         topk_per_class=topk_per_class,
+                                         topk_all=topk_all,
+                                         conf_thres=conf_thres,
+                                         iou_thres=iou_thres)  # keras model
+        if pb or tfjs:  # pb prerequisite to tfjs
+            f[6] = export_pb(model, im, file)
+        if tflite or edgetpu:
+            f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
+        if edgetpu:
+            f[8] = export_edgetpu(model, im, file)
+        if tfjs:
+            f[9] = export_tfjs(model, im, file)
+
+    # Finish
+    f = [str(x) for x in f if x]  # filter out '' and None
+    if any(f):
+        LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'
+                    f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
+                    f"\nDetect:          python detect.py --weights {f[-1]}"
+                    f"\nPyTorch Hub:     model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')"
+                    f"\nValidate:        python val.py --weights {f[-1]}"
+                    f"\nVisualize:       https://netron.app")
+    return f  # return list of exported files/dirs
+
+
+def parse_opt():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
+    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model.pt path(s)')
+    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
+    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
+    parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+    parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
+    parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
+    parser.add_argument('--train', action='store_true', help='model.train() mode')
+    parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
+    parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
+    parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
+    parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
+    parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
+    parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
+    parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
+    parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
+    parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
+    parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
+    parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
+    parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
+    parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
+    parser.add_argument('--include',
+                        nargs='+',
+                        default=['torchscript', 'onnx'],
+                        help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
+    opt = parser.parse_args()
+    print_args(vars(opt))
+    return opt
+
+
+def main(opt):
+    for opt.weights in (opt.weights if isinstance(opt.weights, list) else [opt.weights]):
+        run(**vars(opt))
+
+
+if __name__ == "__main__":
+    opt = parse_opt()
+    main(opt)

+ 3 - 0
code/yolov5/export.sh

@@ -0,0 +1,3 @@
+#python export.py --weights runs/train/hebeixuangang/weights/best.pt --img-size 1280 --include onnx
+python export.py --weights ./runs/train/baoanjichang/weights/best.pt --include pb
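+# e.g. (illustrative, not part of the original script): ONNX export with dynamic axes and simplification
+#   python export.py --weights yolov5s.pt --include onnx --dynamic --simplify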
+

+ 145 - 0
code/yolov5/hubconf.py

@@ -0,0 +1,145 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+PyTorch Hub models https://pytorch.org/hub/ultralytics_yolov5/
+
+Usage:
+    import torch
+    model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
+    model = torch.hub.load('ultralytics/yolov5:master', 'custom', 'path/to/yolov5s.onnx')  # file from branch
+"""
+
+import torch
+
+
+def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
+    """Creates or loads a YOLOv5 model
+
+    Arguments:
+        name (str): model name 'yolov5s' or path 'path/to/best.pt'
+        pretrained (bool): load pretrained weights into the model
+        channels (int): number of input channels
+        classes (int): number of model classes
+        autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
+        verbose (bool): print all information to screen
+        device (str, torch.device, None): device to use for model parameters
+
+    Returns:
+        YOLOv5 model
+    """
+    from pathlib import Path
+
+    from models.common import AutoShape, DetectMultiBackend
+    from models.yolo import Model
+    from utils.downloads import attempt_download
+    from utils.general import LOGGER, check_requirements, intersect_dicts, logging
+    from utils.torch_utils import select_device
+
+    if not verbose:
+        LOGGER.setLevel(logging.WARNING)
+    check_requirements(exclude=('tensorboard', 'thop', 'opencv-python'))
+    name = Path(name)
+    path = name.with_suffix('.pt') if name.suffix == '' else name  # checkpoint path
+    try:
+        device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device)
+
+        if pretrained and channels == 3 and classes == 80:
+            model = DetectMultiBackend(path, device=device)  # download/load FP32 model
+            # model = models.experimental.attempt_load(path, map_location=device)  # download/load FP32 model
+        else:
+            cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0]  # model.yaml path
+            model = Model(cfg, channels, classes)  # create model
+            if pretrained:
+                ckpt = torch.load(attempt_download(path), map_location=device)  # load
+                csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
+                csd = intersect_dicts(csd, model.state_dict(), exclude=['anchors'])  # intersect
+                model.load_state_dict(csd, strict=False)  # load
+                if len(ckpt['model'].names) == classes:
+                    model.names = ckpt['model'].names  # set class names attribute
+        if autoshape:
+            model = AutoShape(model)  # for file/URI/PIL/cv2/np inputs and NMS
+        return model.to(device)
+
+    except Exception as e:
+        help_url = 'https://github.com/ultralytics/yolov5/issues/36'
+        s = f'{e}. Cache may be out of date, try `force_reload=True` or see {help_url} for help.'
+        raise Exception(s) from e
+
+
+def custom(path='path/to/model.pt', autoshape=True, verbose=True, device=None):
+    # YOLOv5 custom or local model
+    return _create(path, autoshape=autoshape, verbose=verbose, device=device)
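+# NOTE (illustrative, not part of the original code): a local checkpoint can be loaded directly, e.g.
+#   model = custom(path='runs/train/exp/weights/best.pt', device='cpu')  # hypothetical path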
+
+
+def yolov5n(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
+    # YOLOv5-nano model https://github.com/ultralytics/yolov5
+    return _create('yolov5n', pretrained, channels, classes, autoshape, verbose, device)
+
+
+def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
+    # YOLOv5-small model https://github.com/ultralytics/yolov5
+    return _create('yolov5s', pretrained, channels, classes, autoshape, verbose, device)
+
+
+def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
+    # YOLOv5-medium model https://github.com/ultralytics/yolov5
+    return _create('yolov5m', pretrained, channels, classes, autoshape, verbose, device)
+
+
+def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
+    # YOLOv5-large model https://github.com/ultralytics/yolov5
+    return _create('yolov5l', pretrained, channels, classes, autoshape, verbose, device)
+
+
+def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
+    # YOLOv5-xlarge model https://github.com/ultralytics/yolov5
+    return _create('yolov5x', pretrained, channels, classes, autoshape, verbose, device)
+
+
+def yolov5n6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
+    # YOLOv5-nano-P6 model https://github.com/ultralytics/yolov5
+    return _create('yolov5n6', pretrained, channels, classes, autoshape, verbose, device)
+
+
+def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
+    # YOLOv5-small-P6 model https://github.com/ultralytics/yolov5
+    return _create('yolov5s6', pretrained, channels, classes, autoshape, verbose, device)
+
+
+def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
+    # YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5
+    return _create('yolov5m6', pretrained, channels, classes, autoshape, verbose, device)
+
+
+def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
+    # YOLOv5-large-P6 model https://github.com/ultralytics/yolov5
+    return _create('yolov5l6', pretrained, channels, classes, autoshape, verbose, device)
+
+
+def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
+    # YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5
+    return _create('yolov5x6', pretrained, channels, classes, autoshape, verbose, device)
+
+
+if __name__ == '__main__':
+    model = _create(name='yolov5s', pretrained=True, channels=3, classes=80, autoshape=True, verbose=True)  # pretrained
+    # model = custom(path='path/to/model.pt')  # custom
+
+    # Verify inference
+    from pathlib import Path
+
+    import numpy as np
+    from PIL import Image
+
+    from utils.general import cv2
+
+    imgs = [
+        'data/images/zidane.jpg',  # filename
+        Path('data/images/zidane.jpg'),  # Path
+        'https://ultralytics.com/images/zidane.jpg',  # URI
+        cv2.imread('data/images/bus.jpg')[:, :, ::-1],  # OpenCV
+        Image.open('data/images/bus.jpg'),  # PIL
+        np.zeros((320, 640, 3))]  # numpy
+
+    results = model(imgs, size=320)  # batched inference
+    results.print()
+    results.save()

+ 0 - 0
code/yolov5/models/__init__.py


+ 842 - 0
code/yolov5/models/common.py

@@ -0,0 +1,842 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+Common modules
+"""
+
+import json
+import math
+import platform
+import warnings
+from collections import OrderedDict, namedtuple
+from copy import copy
+from pathlib import Path
+
+import cv2
+import numpy as np
+import pandas as pd
+import requests
+import torch
+import torch.nn as nn
+import yaml
+from PIL import Image
+from torch.cuda import amp
+
+from utils.datasets import exif_transpose, letterbox
+from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
+                           make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
+from utils.plots import Annotator, colors, save_one_box
+from utils.torch_utils import copy_attr, time_sync
+
+# Auto-pad for 'same' convolution or 'same' pooling
+def autopad(k, p=None):  # kernel, padding
+    """
+    用于Conv函数和Classify函数,根据卷积核大小k自动计算卷积核和padding数
+    v5中只有两种卷积:
+        1.下采样卷积:conv3*3 s=2 p=k//2=1
+        2.feature size不变的卷积:conv1*1 s=1 p=k//2=1
+    :param k: 卷积核的kernel_size
+    :type k:
+    :param p:自动计算的pad值
+    :type p:
+    :return:
+    :rtype:
+    """
+    # Pad to 'same'
+    if p is None:
+        p = k // 2 if isinstance(k, int) else (x // 2 for x in k)  # auto-pad
+    return p
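+# NOTE (illustrative sanity check, not part of the original code): for stride-1 convs this padding
+# preserves spatial size, e.g. autopad(3) -> 1, autopad(5) -> 2, autopad(1) -> 0.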
+
+
+class Conv(nn.Module):
+    # Standard convolution: Conv2d + BatchNorm + SiLU
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
+        """
+        在Focus、Bottleneck、BottleneckCSP、C3、SPP、DWConv、TransformerBloc等模块中调用的基础组件
+        :param c1:输入的channel值
+        :type c1:
+        :param c2:输出的channel值
+        :type c2:
+        :param k:卷积的kernel_size
+        :type k:
+        :param s:卷积的stride
+        :type s:
+        :param p:卷积的padding数,可以通过autopad自行计算padding数
+        :type p:
+        :param g:卷积的groups数 一般等于1为普通卷积,大于1就是深度可分离卷积
+        :type g:
+        :param act:激活函数类型 True就是SiLU
+        :type act:
+        """
+        super().__init__()
+        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
+        self.bn = nn.BatchNorm2d(c2)
+        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
+
+    def forward(self, x):  # the network's execution order is determined by the forward function
+        return self.act(self.bn(self.conv(x)))
+
+    def forward_fuse(self, x):
+        """
+        用于Model类的fuse函数
+        相较于forward函数去掉了BN层,加速推理,一般用于测试/验证阶段
+        :param x:
+        :type x:
+        :return:
+        :rtype:
+        """
+        return self.act(self.conv(x))
+
+
+class DWConv(Conv):
+    # Depth-wise convolution class
+    def __init__(self, c1, c2, k=1, s=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
+        super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
+
+
+class TransformerLayer(nn.Module):
+    # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
+    def __init__(self, c, num_heads):
+        super().__init__()
+        self.q = nn.Linear(c, c, bias=False)
+        self.k = nn.Linear(c, c, bias=False)
+        self.v = nn.Linear(c, c, bias=False)
+        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
+        self.fc1 = nn.Linear(c, c, bias=False)
+        self.fc2 = nn.Linear(c, c, bias=False)
+
+    def forward(self, x):
+        x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
+        x = self.fc2(self.fc1(x)) + x
+        return x
+
+
+class TransformerBlock(nn.Module):
+    # Vision Transformer https://arxiv.org/abs/2010.11929
+    def __init__(self, c1, c2, num_heads, num_layers):
+        super().__init__()
+        self.conv = None
+        if c1 != c2:
+            self.conv = Conv(c1, c2)
+        self.linear = nn.Linear(c2, c2)  # learnable position embedding
+        self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
+        self.c2 = c2
+
+    def forward(self, x):
+        if self.conv is not None:
+            x = self.conv(x)
+        b, _, w, h = x.shape
+        p = x.flatten(2).permute(2, 0, 1)
+        return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
+
+
+class Bottleneck(nn.Module):
+    # Standard bottleneck
+    """
+    由1*1conv、3*3conv、残差块组成
+    """
+    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
+        """
+        在BottleneckCSP、C3、parse_model中调用
+        组件分为两种情况,当shortcut为True时,bottleneck需要在经过1*1卷积和3*3卷积后在经过shortcut
+            当shortcut为False时,bottleneck只需要经过1*1卷积和3*3卷积即可
+        :param c1:输入channel
+        :type c1:
+        :param c2:输出channel
+        :type c2:
+        :param shortcut:是否进行shortcut 默认为True
+        :type shortcut:
+        :param g: 卷积的groups数 等于1普通卷积 大于1深度可分离卷积
+        :type g:
+        :param e:膨胀系数
+        :type e:
+        """
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels (intermediate layer)
+        self.cv1 = Conv(c1, c_, 1, 1)  # first conv outputs c_ channels
+        self.cv2 = Conv(c_, c2, 3, 1, g=g)  # second conv takes c_ channels as input
+        self.add = shortcut and c1 == c2
+
+    def forward(self, x):
+        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
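+# NOTE (illustrative shape sketch, not part of the original code): with c1 == c2 the residual add applies, e.g.
+#   Bottleneck(64, 64)(torch.randn(1, 64, 80, 80)).shape -> torch.Size([1, 64, 80, 80])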
+
+
+
+class BottleneckCSP(nn.Module):
+    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
+        """
+        该组件由bottleneck模块和CSP模块组成,此模块与C3模块等效。
+        :param c1:输入channel
+        :type c1:
+        :param c2:输出channel
+        :type c2:
+        :param n:有n个bottleneck
+        :type n:
+        :param shortcut:bottleneck中是shortcut,默认为True
+        :type shortcut:
+        :param g: bottleneck中的groups 等于1,普通卷积 大于1,深度可分离卷积
+        :type g:
+        :param e:bottleneck中的膨胀系数
+        :type e:
+        """
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)  # Conv + BN + SiLU
+        self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
+        self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
+        self.cv4 = Conv(2 * c_, c2, 1, 1)
+        self.bn = nn.BatchNorm2d(2 * c_)  # applied to cat(cv2, cv3)
+        self.act = nn.SiLU()
+        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))  # stack n bottlenecks
+
+    def forward(self, x):
+        y1 = self.cv3(self.m(self.cv1(x)))
+        y2 = self.cv2(x)
+        return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
+
+
+class C3(nn.Module):
+    # CSP Bottleneck with 3 convolutions
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
+        """
+        简化版的bottleneckCSP模块,除了bottleneck部分整个结构只有3个卷积,可以减少参数
+        :param c1: 输入channel
+        :type c1:
+        :param c2: 输出channel
+        :type c2:
+        :param n: 有n个bottleneck
+        :type n:
+        :param shortcut: bottleneck中是否有shortcut,默认为True
+        :type shortcut:
+        :param g: bottleneck中的groups 等于1,普通卷积 大于1,深度可分离卷积
+        :type g:
+        :param e: bottleneck中的膨胀系数
+        :type e:
+        """
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c1, c_, 1, 1)
+        self.cv3 = Conv(2 * c_, c2, 1)  # optional act=FReLU(c2)
+        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
+        # self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)))
+
+    def forward(self, x):
+        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
+
+
+class C3TR(C3):
+    # C3 module with TransformerBlock()
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
+        super().__init__(c1, c2, n, shortcut, g, e)
+        c_ = int(c2 * e)
+        self.m = TransformerBlock(c_, c_, 4, n)
+
+
+class C3SPP(C3):
+    # C3 module with SPP()
+    def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5):
+        super().__init__(c1, c2, n, shortcut, g, e)
+        c_ = int(c2 * e)
+        self.m = SPP(c_, c_, k)
+
+
+class C3Ghost(C3):
+    # C3 module with GhostBottleneck()
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
+        super().__init__(c1, c2, n, shortcut, g, e)
+        c_ = int(c2 * e)  # hidden channels
+        self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
+
+
+class SPP(nn.Module):
+    # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
+    def __init__(self, c1, c2, k=(5, 9, 13)):
+        """
+        空间金字塔池化
+        :param c1:  输入channel
+        :type c1:
+        :param c2:  输出channel
+        :type c2:
+        :param k:  保存着三个maxpool卷积的kernel_size。默认是(5, 9, 13)
+        :type k:
+        """
+        super().__init__()
+        c_ = c1 // 2  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)  # first conv layer
+        self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)  # final conv layer
+        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])  # intermediate maxpool layers
+
+    def forward(self, x):
+        x = self.cv1(x)
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
+            return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
+
+
+class SPPF(nn.Module):
+    # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
+    def __init__(self, c1, c2, k=5):  # equivalent to SPP(k=(5, 9, 13))
+        """
+        SPP的升级改进版,将5*5,9*9,13*13三个amxpool并行输出的结果改成了3个5*5的maxpool串行输出的结果。结果是提升了计算速度
+        :param c1: 输入channel
+        :type c1:
+        :param c2: 输出channel
+        :type c2:
+        :param k: 卷积的kernel_size
+        :type k:
+        """
+        super().__init__()
+        c_ = c1 // 2  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c_ * 4, c2, 1, 1)
+        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
+
+    def forward(self, x):
+        x = self.cv1(x)
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
+            y1 = self.m(x)
+            y2 = self.m(y1)
+            return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
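+# NOTE (illustrative, not part of the original code): three serial k=5 maxpools have the same receptive
+# fields as parallel k=5/9/13 maxpools, so with identical cv1/cv2 weights SPPF(c1, c2, 5) reproduces
+# SPP(c1, c2, (5, 9, 13)) outputs while running faster.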
+
+
+class Focus(nn.Module):  # Focus: fold width (w) and height (h) information into the channel (c) space.
+    """
+    The Focus module exists to reduce computation and improve speed; it does not increase network accuracy.
+
+    It periodically samples pixels from the high-resolution image into a lower-resolution one, stacking the four neighboring positions so that w/h information is folded into the channel dimension. This enlarges each point's receptive field while minimizing the loss of original information.
+    """
+    # Focus wh information into c-space
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
+        """
+
+        :param c1: 输入的channel数
+        :type c1:
+        :param c2: Focus输出的channel数
+        :type c2:
+        :param k: 卷积的kernel_size
+        :type k:
+        :param s: 卷积的stride
+        :type s:
+        :param p: 卷积的padding
+        :type p:
+        :param g: 卷积的groups 等于1为普通卷积 大于1为深度可分离卷积
+        :type g:
+        :param act:激活函数类型 True:SiLU/Swish False:不使用激活函数
+        :type act:
+        """
+        super().__init__()
+        self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
+        # self.contract = Contract(gain=2)
+
+    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
+        return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
+        # return self.conv(self.contract(x))
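+# NOTE (illustrative shape sketch, not part of the original code): the slicing rearranges pixels before the conv, e.g.
+#   Focus(3, 64)(torch.randn(1, 3, 640, 640)).shape -> torch.Size([1, 64, 320, 320])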
+
+
+class GhostConv(nn.Module):
+    # Ghost Convolution https://github.com/huawei-noah/ghostnet
+    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
+        super().__init__()
+        c_ = c2 // 2  # hidden channels
+        self.cv1 = Conv(c1, c_, k, s, None, g, act)
+        self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)
+
+    def forward(self, x):
+        y = self.cv1(x)
+        return torch.cat((y, self.cv2(y)), 1)
+
+
+class GhostBottleneck(nn.Module):
+    # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
+    def __init__(self, c1, c2, k=3, s=1):  # ch_in, ch_out, kernel, stride
+        super().__init__()
+        c_ = c2 // 2
+        self.conv = nn.Sequential(
+            GhostConv(c1, c_, 1, 1),  # pw
+            DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(),  # dw
+            GhostConv(c_, c2, 1, 1, act=False))  # pw-linear
+        self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1,
+                                                                            act=False)) if s == 2 else nn.Identity()
+
+    def forward(self, x):
+        return self.conv(x) + self.shortcut(x)
+
+
+class Contract(nn.Module):
+    # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
+    def __init__(self, gain=2):
+        """
+        Focus模块的辅助函数,目的是改变输入特征的shape w和h维度的数据减半后将channel通道数提升4倍
+        :param gain:
+        :type gain:
+        """
+        super().__init__()
+        self.gain = gain
+
+    def forward(self, x):
+        b, c, h, w = x.size()  # assert (h % s == 0) and (w % s == 0), 'Indivisible gain'
+        s = self.gain
+        x = x.view(b, c, h // s, s, w // s, s)  # x(1,64,40,2,40,2)
+        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # x(1,2,2,64,40,40)
+        return x.view(b, c * s * s, h // s, w // s)  # x(1,256,40,40)
+
+
+class Expand(nn.Module):
+    # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
+    def __init__(self, gain=2):
+        """
+        Contract函数的还原函数,目的是将channel维度(缩小4倍)的数据扩展到W和H维度(扩大两倍)
+        :param gain:
+        :type gain:
+        """
+        super().__init__()
+        self.gain = gain
+
+    def forward(self, x):
+        b, c, h, w = x.size()  # assert c % s ** 2 == 0, 'Indivisible gain'
+        s = self.gain
+        x = x.view(b, s, s, c // s ** 2, h, w)  # x(1,2,2,16,80,80)
+        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # x(1,16,80,2,80,2)
+        return x.view(b, c // s ** 2, h * s, w * s)  # x(1,16,160,160)
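+# NOTE (illustrative, not part of the original code): with the same gain, Expand inverts Contract, e.g.
+#   x = torch.randn(1, 64, 80, 80)
+#   torch.equal(Expand(2)(Contract(2)(x)), x)  -> True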
+
+
+class Concat(nn.Module):
+    # Concatenate a list of tensors along dimension
+    def __init__(self, dimension=1):
+        """
+        按指定维度进行拼接
+        :param dimension:维度
+        :type dimension:
+        """
+        super().__init__()
+        self.d = dimension
+
+    def forward(self, x):
+        return torch.cat(x, self.d)
+
+
+class DetectMultiBackend(nn.Module):  # YOLOv5 multi-backend model inference
+    # YOLOv5 MultiBackend class for python inference on various backends
+    def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False):
+        # Usage:
+        #   PyTorch:              weights = *.pt
+        #   TorchScript:                    *.torchscript
+        #   ONNX Runtime:                   *.onnx
+        #   ONNX OpenCV DNN:                *.onnx with --dnn
+        #   OpenVINO:                       *.xml
+        #   CoreML:                         *.mlmodel
+        #   TensorRT:                       *.engine
+        #   TensorFlow SavedModel:          *_saved_model
+        #   TensorFlow GraphDef:            *.pb
+        #   TensorFlow Lite:                *.tflite
+        #   TensorFlow Edge TPU:            *_edgetpu.tflite
+        from models.experimental import attempt_download, attempt_load  # scoped to avoid circular import
+
+        super().__init__()
+        w = str(weights[0] if isinstance(weights, list) else weights)  # weights filename
+        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.model_type(w)  # get backend: one boolean per supported format, True for the matching type
+        stride, names = 32, [f'class{i}' for i in range(1000)]  # assign defaults: stride 32 and 1000 placeholder class names
+        w = attempt_download(w)  # download weights if not found locally
+        fp16 &= (pt or jit or onnx or engine) and device.type != 'cpu'  # FP16
+        if data:  # data.yaml path (optional): read class names from it if provided
+            with open(data, errors='ignore') as f:
+                names = yaml.safe_load(f)['names']  # class names
+
+        if pt:  # PyTorch
+            model = attempt_load(weights if isinstance(weights, list) else w, map_location=device)  # load weights
+            stride = max(int(model.stride.max()), 32)  # model stride: max downsampling factor (at least 32)
+            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
+            model.half() if fp16 else model.float()  # half/full precision
+            self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
+        elif jit:  # TorchScript
+            LOGGER.info(f'Loading {w} for TorchScript inference...')
+            extra_files = {'config.txt': ''}  # model metadata
+            model = torch.jit.load(w, _extra_files=extra_files)
+            model.half() if fp16 else model.float()
+            if extra_files['config.txt']:
+                d = json.loads(extra_files['config.txt'])  # extra_files dict
+                stride, names = int(d['stride']), d['names']
+        elif dnn:  # ONNX OpenCV DNN
+            LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
+            check_requirements(('opencv-python>=4.5.4',))
+            net = cv2.dnn.readNetFromONNX(w)
+        elif onnx:  # ONNX Runtime
+            LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
+            cuda = torch.cuda.is_available()
+            check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
+            import onnxruntime
+            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
+            session = onnxruntime.InferenceSession(w, providers=providers)
+            meta = session.get_modelmeta().custom_metadata_map  # metadata
+            if 'stride' in meta:
+                stride, names = int(meta['stride']), eval(meta['names'])
+        elif xml:  # OpenVINO
+            LOGGER.info(f'Loading {w} for OpenVINO inference...')
+            check_requirements(('openvino-dev',))  # requires openvino-dev: https://pypi.org/project/openvino-dev/
+            import openvino.inference_engine as ie
+            core = ie.IECore()
+            if not Path(w).is_file():  # if not *.xml
+                w = next(Path(w).glob('*.xml'))  # get *.xml file from *_openvino_model dir
+            network = core.read_network(model=w, weights=Path(w).with_suffix('.bin'))  # *.xml, *.bin paths
+            executable_network = core.load_network(network, device_name='CPU', num_requests=1)
+        elif engine:  # TensorRT
+            LOGGER.info(f'Loading {w} for TensorRT inference...')
+            import tensorrt as trt  # https://developer.nvidia.com/nvidia-tensorrt-download
+            check_version(trt.__version__, '7.0.0', hard=True)  # require tensorrt>=7.0.0
+            Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
+            logger = trt.Logger(trt.Logger.INFO)
+            with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
+                model = runtime.deserialize_cuda_engine(f.read())
+            bindings = OrderedDict()
+            fp16 = False  # default updated below
+            for index in range(model.num_bindings):
+                name = model.get_binding_name(index)
+                dtype = trt.nptype(model.get_binding_dtype(index))
+                shape = tuple(model.get_binding_shape(index))
+                data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
+                bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
+                if model.binding_is_input(index) and dtype == np.float16:
+                    fp16 = True
+            binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
+            context = model.create_execution_context()
+            batch_size = bindings['images'].shape[0]
+        elif coreml:  # CoreML
+            LOGGER.info(f'Loading {w} for CoreML inference...')
+            import coremltools as ct
+            model = ct.models.MLModel(w)
+        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
+            if saved_model:  # SavedModel
+                LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
+                import tensorflow as tf
+                keras = False  # assume TF1 saved_model
+                model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
+            elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
+                LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
+                import tensorflow as tf
+
+                def wrap_frozen_graph(gd, inputs, outputs):
+                    x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
+                    ge = x.graph.as_graph_element
+                    return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
+
+                gd = tf.Graph().as_graph_def()  # graph_def
+                with open(w, 'rb') as f:
+                    gd.ParseFromString(f.read())
+                frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs="Identity:0")
+            elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
+                try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
+                    from tflite_runtime.interpreter import Interpreter, load_delegate
+                except ImportError:
+                    import tensorflow as tf
+                    Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
+                if edgetpu:  # Edge TPU https://coral.ai/software/#edgetpu-runtime
+                    LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
+                    delegate = {
+                        'Linux': 'libedgetpu.so.1',
+                        'Darwin': 'libedgetpu.1.dylib',
+                        'Windows': 'edgetpu.dll'}[platform.system()]
+                    interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
+                else:  # Lite
+                    LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
+                    interpreter = Interpreter(model_path=w)  # load TFLite model
+                interpreter.allocate_tensors()  # allocate
+                input_details = interpreter.get_input_details()  # inputs
+                output_details = interpreter.get_output_details()  # outputs
+            elif tfjs:
+                raise Exception('ERROR: YOLOv5 TF.js inference is not supported')
+        self.__dict__.update(locals())  # assign all variables to self
+
+    def forward(self, im, augment=False, visualize=False, val=False):
+        # YOLOv5 MultiBackend inference across the supported backends
+        b, ch, h, w = im.shape  # batch, channel, height, width
+        if self.pt:  # PyTorch
+            y = self.model(im, augment=augment, visualize=visualize)[0]
+        elif self.jit:  # TorchScript
+            y = self.model(im)[0]
+        elif self.dnn:  # ONNX OpenCV DNN
+            im = im.cpu().numpy()  # torch to numpy
+            self.net.setInput(im)
+            y = self.net.forward()
+        elif self.onnx:  # ONNX Runtime
+            im = im.cpu().numpy()  # torch to numpy
+            y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
+        elif self.xml:  # OpenVINO
+            im = im.cpu().numpy()  # FP32
+            desc = self.ie.TensorDesc(precision='FP32', dims=im.shape, layout='NCHW')  # Tensor Description
+            request = self.executable_network.requests[0]  # inference request
+            request.set_blob(blob_name='images', blob=self.ie.Blob(desc, im))  # name=next(iter(request.input_blobs))
+            request.infer()
+            y = request.output_blobs['output'].buffer  # name=next(iter(request.output_blobs))
+        elif self.engine:  # TensorRT
+            assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
+            self.binding_addrs['images'] = int(im.data_ptr())
+            self.context.execute_v2(list(self.binding_addrs.values()))
+            y = self.bindings['output'].data
+        elif self.coreml:  # CoreML
+            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
+            im = Image.fromarray((im[0] * 255).astype('uint8'))
+            # im = im.resize((192, 320), Image.ANTIALIAS)
+            y = self.model.predict({'image': im})  # coordinates are xywh normalized
+            if 'confidence' in y:
+                box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
+                conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)  # np.float alias was removed in NumPy 1.24
+                y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
+            else:
+                k = 'var_' + str(sorted(int(k.replace('var_', '')) for k in y)[-1])  # output key
+                y = y[k]  # output
+        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
+            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
+            if self.saved_model:  # SavedModel
+                y = (self.model(im, training=False) if self.keras else self.model(im)).numpy()
+            elif self.pb:  # GraphDef
+                y = self.frozen_func(x=self.tf.constant(im)).numpy()
+            else:  # Lite or Edge TPU
+                input, output = self.input_details[0], self.output_details[0]
+                int8 = input['dtype'] == np.uint8  # is TFLite quantized uint8 model
+                if int8:
+                    scale, zero_point = input['quantization']
+                    im = (im / scale + zero_point).astype(np.uint8)  # de-scale
+                self.interpreter.set_tensor(input['index'], im)
+                self.interpreter.invoke()
+                y = self.interpreter.get_tensor(output['index'])
+                if int8:
+                    scale, zero_point = output['quantization']
+                    y = (y.astype(np.float32) - zero_point) * scale  # re-scale
+            y[..., :4] *= [w, h, w, h]  # xywh normalized to pixels
+
+        if isinstance(y, np.ndarray):
+            y = torch.tensor(y, device=self.device)
+        return (y, []) if val else y
+
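The TFLite branch of forward() above applies standard affine quantization. A minimal sketch of the two formulas, with made-up scale/zero_point values standing in for input['quantization']:

    import numpy as np

    scale, zero_point = 0.003921569, 0           # example values only
    real = np.float32(0.5)
    q = np.uint8(real / scale + zero_point)      # quantize: uint8 input for the interpreter
    back = (np.float32(q) - zero_point) * scale  # dequantize the model output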
+    def warmup(self, imgsz=(1, 3, 640, 640)):  # warm up the model with one dummy inference
+        # Warmup model by running inference once
+        if any((self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb)):  # warmup-capable backend types
+            if self.device.type != 'cpu':  # only warmup GPU models
+                im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # zero-filled dummy input
+                for _ in range(2 if self.jit else 1):  # TorchScript benefits from a second pass
+                    self.forward(im)  # warmup
+
+    @staticmethod
+    def model_type(p='path/to/model.pt'):  # infer the model type from its file path
+        # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
+        from export import export_formats
+        suffixes = list(export_formats().Suffix) + ['.xml']  # export suffixes supported by YOLOv5
+        check_suffix(p, suffixes)  # check the model file suffix
+        p = Path(p).name  # eliminate trailing separators and directory info
+        pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, xml2 = (s in p for s in suffixes)
+        xml |= xml2  # *_openvino_model or *.xml
+        tflite &= not edgetpu  # *.tflite
+        return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs
+
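A typical end-to-end use of DetectMultiBackend, as a minimal sketch (the weights path is an assumption and will be downloaded by attempt_download if missing):

    import torch

    model = DetectMultiBackend('yolov5s.pt', device=torch.device('cpu'))
    model.warmup(imgsz=(1, 3, 640, 640))  # no-op on CPU, dummy pass on GPU
    im = torch.zeros(1, 3, 640, 640)      # BCHW, 0-1 float input
    pred = model(im)                      # raw predictions, ready for non_max_suppression()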
+
+class AutoShape(nn.Module):  # input-shape-robust model wrapper
+    # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
+    conf = 0.25  # NMS confidence threshold
+    iou = 0.45  # NMS IoU threshold
+    agnostic = False  # NMS class-agnostic
+    multi_label = False  # NMS multiple labels per box
+    classes = None  # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
+    max_det = 1000  # maximum number of detections per image
+    amp = False  # Automatic Mixed Precision (AMP) inference
+
+    def __init__(self, model):
+        super().__init__()
+        LOGGER.info('Adding AutoShape... ')
+        copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=())  # copy attributes
+        self.dmb = isinstance(model, DetectMultiBackend)  # DetectMultiBackend() instance
+        self.pt = not self.dmb or model.pt  # PyTorch model
+        self.model = model.eval()
+
+    def _apply(self, fn):
+        # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
+        self = super()._apply(fn)
+        if self.pt:
+            m = self.model.model.model[-1] if self.dmb else self.model.model[-1]  # Detect()
+            m.stride = fn(m.stride)
+            m.grid = list(map(fn, m.grid))
+            if isinstance(m.anchor_grid, list):
+                m.anchor_grid = list(map(fn, m.anchor_grid))
+        return self
+
+    @torch.no_grad()
+    def forward(self, imgs, size=640, augment=False, profile=False):
+        # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
+        #   file:       imgs = 'data/images/zidane.jpg'  # str or PosixPath
+        #   URI:             = 'https://ultralytics.com/images/zidane.jpg'
+        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
+        #   PIL:             = Image.open('image.jpg') or ImageGrab.grab()  # HWC x(640,1280,3)
+        #   numpy:           = np.zeros((640,1280,3))  # HWC
+        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
+        #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
+
+        t = [time_sync()]
+        p = next(self.model.parameters()) if self.pt else torch.zeros(1)  # for device and type
+        autocast = self.amp and (p.device.type != 'cpu')  # Automatic Mixed Precision (AMP) inference
+        if isinstance(imgs, torch.Tensor):  # torch
+            with amp.autocast(autocast):
+                return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference
+
+        # Pre-process
+        n, imgs = (len(imgs), list(imgs)) if isinstance(imgs, (list, tuple)) else (1, [imgs])  # number, list of images
+        shape0, shape1, files = [], [], []  # image and inference shapes, filenames
+        for i, im in enumerate(imgs):
+            f = f'image{i}'  # filename
+            if isinstance(im, (str, Path)):  # filename or uri
+                im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
+                im = np.asarray(exif_transpose(im))
+            elif isinstance(im, Image.Image):  # PIL Image
+                im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
+            files.append(Path(f).with_suffix('.jpg').name)
+            if im.shape[0] < 5:  # image in CHW
+                im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
+            im = im[..., :3] if im.ndim == 3 else np.tile(im[..., None], 3)  # enforce 3ch input
+            s = im.shape[:2]  # HWC
+            shape0.append(s)  # image shape
+            g = (size / max(s))  # gain
+            shape1.append([y * g for y in s])
+            imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
+        shape1 = [make_divisible(x, self.stride) if self.pt else size for x in np.array(shape1).max(0)]  # inf shape
+        x = [letterbox(im, shape1, auto=False)[0] for im in imgs]  # pad
+        x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2)))  # stack and BHWC to BCHW
+        x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32
+        t.append(time_sync())
+
+        with amp.autocast(autocast):
+            # Inference
+            y = self.model(x, augment, profile)  # forward
+            t.append(time_sync())
+
+            # Post-process
+            y = non_max_suppression(y if self.dmb else y[0],
+                                    self.conf,
+                                    self.iou,
+                                    self.classes,
+                                    self.agnostic,
+                                    self.multi_label,
+                                    max_det=self.max_det)  # NMS
+            for i in range(n):
+                scale_coords(shape1, y[i][:, :4], shape0[i])
+
+            t.append(time_sync())
+            return Detections(imgs, y, files, t, self.names, x.shape)
+
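A sketch of the wrapper in use (the image path is an assumption): AutoShape handles letterboxing, normalization and NMS, and returns the Detections object defined below.

    model = AutoShape(DetectMultiBackend('yolov5s.pt'))
    results = model('data/images/zidane.jpg', size=640)
    results.print()                 # per-image speed and class counts
    df = results.pandas().xyxy[0]   # detections as a DataFrame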
+
+class Detections:
+    # YOLOv5 detections class for inference results
+    def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None):
+        super().__init__()
+        d = pred[0].device  # device
+        gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs]  # normalizations
+        self.imgs = imgs  # list of images as numpy arrays
+        self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
+        self.names = names  # class names
+        self.files = files  # image filenames
+        self.times = times  # profiling times
+        self.xyxy = pred  # xyxy pixels
+        self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
+        self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
+        self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
+        self.n = len(self.pred)  # number of images (batch size)
+        self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3))  # per-image times (ms)
+        self.s = shape  # inference BCHW shape
+
+    def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
+        crops = []
+        for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
+            s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '  # string
+            if pred.shape[0]:
+                for c in pred[:, -1].unique():
+                    n = (pred[:, -1] == c).sum()  # detections per class
+                    s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
+                if show or save or render or crop:
+                    annotator = Annotator(im, example=str(self.names))
+                    for *box, conf, cls in reversed(pred):  # xyxy, confidence, class
+                        label = f'{self.names[int(cls)]} {conf:.2f}'
+                        if crop:
+                            file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
+                            crops.append({
+                                'box': box,
+                                'conf': conf,
+                                'cls': cls,
+                                'label': label,
+                                'im': save_one_box(box, im, file=file, save=save)})
+                        else:  # all others
+                            annotator.box_label(box, label if labels else '', color=colors(cls))
+                    im = annotator.im
+            else:
+                s += '(no detections)'
+
+            im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im  # from np
+            if pprint:
+                LOGGER.info(s.rstrip(', '))
+            if show:
+                im.show(self.files[i])  # show
+            if save:
+                f = self.files[i]
+                im.save(save_dir / f)  # save
+                if i == self.n - 1:
+                    LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
+            if render:
+                self.imgs[i] = np.asarray(im)
+        if crop:
+            if save:
+                LOGGER.info(f'Saved results to {save_dir}\n')
+            return crops
+
+    def print(self):
+        self.display(pprint=True)  # print results
+        LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
+                    self.t)
+
+    def show(self, labels=True):
+        self.display(show=True, labels=labels)  # show results
+
+    def save(self, labels=True, save_dir='runs/detect/exp'):
+        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True)  # increment save_dir
+        self.display(save=True, labels=labels, save_dir=save_dir)  # save results
+
+    def crop(self, save=True, save_dir='runs/detect/exp'):
+        save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) if save else None
+        return self.display(crop=True, save=save, save_dir=save_dir)  # crop results
+
+    def render(self, labels=True):
+        self.display(render=True, labels=labels)  # render results
+        return self.imgs
+
+    def pandas(self):
+        # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
+        new = copy(self)  # return copy
+        ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
+        cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
+        for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
+            a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # update
+            setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
+        return new
+
+    def tolist(self):
+        # return a list of Detections objects, i.e. 'for result in results.tolist():'
+        r = range(self.n)  # iterable
+        x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
+        # for d in x:
+        #    for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
+        #        setattr(d, k, getattr(d, k)[0])  # pop out of list
+        return x
+
+    def __len__(self):
+        return self.n
+
+
+class Classify(nn.Module):
+    # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
+        super().__init__()
+        self.aap = nn.AdaptiveAvgPool2d(1)  # to x(b,c1,1,1)
+        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g)  # to x(b,c2,1,1)
+        self.flat = nn.Flatten()
+
+    def forward(self, x):
+        z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1)  # cat if list
+        return self.flat(self.conv(z))  # flatten to x(b,c2)
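A shape sketch for this head: adaptive average pooling collapses the feature map to 1x1, then the 1x1 conv acts as the classifier.

    import torch

    head = Classify(512, 10)                    # 512-ch feature map -> 10 classes
    logits = head(torch.randn(2, 512, 20, 20))
    print(logits.shape)                         # torch.Size([2, 10])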

+ 122 - 0
code/yolov5/models/experimental.py

@@ -0,0 +1,122 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+Experimental modules
+"""
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from models.common import Conv
+from utils.downloads import attempt_download
+
+
+class CrossConv(nn.Module):
+    # Cross Convolution Downsample
+    def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
+        # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, (1, k), (1, s))
+        self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
+        self.add = shortcut and c1 == c2
+
+    def forward(self, x):
+        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
+
+
+class Sum(nn.Module):
+    # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
+    def __init__(self, n, weight=False):  # n: number of inputs
+        super().__init__()
+        self.weight = weight  # apply weights boolean
+        self.iter = range(n - 1)  # iter object
+        if weight:
+            self.w = nn.Parameter(-torch.arange(1.0, n) / 2, requires_grad=True)  # layer weights
+
+    def forward(self, x):
+        y = x[0]  # no weight
+        if self.weight:
+            w = torch.sigmoid(self.w) * 2
+            for i in self.iter:
+                y = y + x[i + 1] * w[i]
+        else:
+            for i in self.iter:
+                y = y + x[i + 1]
+        return y
+
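A quick sketch of weighted fusion with Sum; the learnable weights are squashed into (0, 2) by 2 * sigmoid:

    import torch

    s = Sum(3, weight=True)
    xs = [torch.randn(1, 8, 4, 4) for _ in range(3)]
    y = s(xs)        # y = x0 + w[0]*x1 + w[1]*x2, with w = 2*sigmoid(self.w)
    print(y.shape)   # torch.Size([1, 8, 4, 4])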
+
+class MixConv2d(nn.Module):
+    # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
+    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):  # ch_in, ch_out, kernel, stride, ch_strategy
+        super().__init__()
+        n = len(k)  # number of convolutions
+        if equal_ch:  # equal c_ per group
+            i = torch.linspace(0, n - 1E-6, c2).floor()  # c2 indices
+            c_ = [(i == g).sum() for g in range(n)]  # intermediate channels
+        else:  # equal weight.numel() per group
+            b = [c2] + [0] * n
+            a = np.eye(n + 1, n, k=-1)
+            a -= np.roll(a, 1, axis=1)
+            a *= np.array(k) ** 2
+            a[0] = 1
+            c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b
+
+        self.m = nn.ModuleList([
+            nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
+        self.bn = nn.BatchNorm2d(c2)
+        self.act = nn.SiLU()
+
+    def forward(self, x):
+        return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
+
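Shape check for the mixed-kernel conv: with equal_ch=True the 64 output channels are split roughly evenly between the 1x1 and 3x3 branches, then concatenated.

    import torch

    m = MixConv2d(32, 64, k=(1, 3))
    print(m(torch.randn(1, 32, 16, 16)).shape)  # torch.Size([1, 64, 16, 16])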
+
+class Ensemble(nn.ModuleList):
+    # Ensemble of models
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, augment=False, profile=False, visualize=False):
+        y = []
+        for module in self:
+            y.append(module(x, augment, profile, visualize)[0])
+        # y = torch.stack(y).max(0)[0]  # max ensemble
+        # y = torch.stack(y).mean(0)  # mean ensemble
+        y = torch.cat(y, 1)  # nms ensemble
+        return y, None  # inference, train output
+
+
+def attempt_load(weights, map_location=None, inplace=True, fuse=True):
+    from models.yolo import Detect, Model
+
+    # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
+    model = Ensemble()  # initialize the (possibly single-model) ensemble
+    for w in weights if isinstance(weights, list) else [weights]:
+        ckpt = torch.load(attempt_download(w), map_location=map_location)  # load checkpoint
+        ckpt = (ckpt.get('ema') or ckpt['model']).float()  # FP32 model
+        model.append(ckpt.fuse().eval() if fuse else ckpt.eval())  # fused or un-fused model in eval mode
+
+    # Compatibility updates: align each module with the installed torch version
+    for m in model.modules():
+        t = type(m)
+        if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
+            m.inplace = inplace  # torch 1.7.0 compatibility
+            if t is Detect:
+                if not isinstance(m.anchor_grid, list):  # new Detect Layer compatibility
+                    delattr(m, 'anchor_grid')
+                    setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
+        elif t is Conv:
+            m._non_persistent_buffers_set = set()  # torch 1.6.0 compatibility
+        elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
+            m.recompute_scale_factor = None  # torch 1.11.0 compatibility
+
+    if len(model) == 1:
+        return model[-1]  # return model
+    else:
+        print(f'Ensemble created with {weights}\n')
+        for k in 'names', 'nc', 'yaml':
+            setattr(model, k, getattr(model[0], k))
+        model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
+        assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
+        return model  # return ensemble
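Usage sketch (weight names are assumptions): a single path returns the model itself, while a list returns an Ensemble whose concatenated detections feed a single NMS pass.

    single = attempt_load('yolov5s.pt', map_location='cpu')
    ens = attempt_load(['yolov5s.pt', 'yolov5m.pt'], map_location='cpu')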

+ 59 - 0
code/yolov5/models/hub/anchors.yaml

@@ -0,0 +1,59 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+# Default anchors for COCO data
+
+
+# P5 -------------------------------------------------------------------------------------------------------------------
+# P5-640:
+anchors_p5_640:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+
+# P6 -------------------------------------------------------------------------------------------------------------------
+# P6-640:  thr=0.25: 0.9964 BPR, 5.54 anchors past thr, n=12, img_size=640, metric_all=0.281/0.716-mean/best, past_thr=0.469-mean: 9,11,  21,19,  17,41,  43,32,  39,70,  86,64,  65,131,  134,130,  120,265,  282,180,  247,354,  512,387
+anchors_p6_640:
+  - [9,11,  21,19,  17,41]  # P3/8
+  - [43,32,  39,70,  86,64]  # P4/16
+  - [65,131,  134,130,  120,265]  # P5/32
+  - [282,180,  247,354,  512,387]  # P6/64
+
+# P6-1280:  thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1280, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 19,27,  44,40,  38,94,  96,68,  86,152,  180,137,  140,301,  303,264,  238,542,  436,615,  739,380,  925,792
+anchors_p6_1280:
+  - [19,27,  44,40,  38,94]  # P3/8
+  - [96,68,  86,152,  180,137]  # P4/16
+  - [140,301,  303,264,  238,542]  # P5/32
+  - [436,615,  739,380,  925,792]  # P6/64
+
+# P6-1920:  thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1920, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 28,41,  67,59,  57,141,  144,103,  129,227,  270,205,  209,452,  455,396,  358,812,  653,922,  1109,570,  1387,1187
+anchors_p6_1920:
+  - [28,41,  67,59,  57,141]  # P3/8
+  - [144,103,  129,227,  270,205]  # P4/16
+  - [209,452,  455,396,  358,812]  # P5/32
+  - [653,922,  1109,570,  1387,1187]  # P6/64
+
+
+# P7 -------------------------------------------------------------------------------------------------------------------
+# P7-640:  thr=0.25: 0.9962 BPR, 6.76 anchors past thr, n=15, img_size=640, metric_all=0.275/0.733-mean/best, past_thr=0.466-mean: 11,11,  13,30,  29,20,  30,46,  61,38,  39,92,  78,80,  146,66,  79,163,  149,150,  321,143,  157,303,  257,402,  359,290,  524,372
+anchors_p7_640:
+  - [11,11,  13,30,  29,20]  # P3/8
+  - [30,46,  61,38,  39,92]  # P4/16
+  - [78,80,  146,66,  79,163]  # P5/32
+  - [149,150,  321,143,  157,303]  # P6/64
+  - [257,402,  359,290,  524,372]  # P7/128
+
+# P7-1280:  thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1280, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 19,22,  54,36,  32,77,  70,83,  138,71,  75,173,  165,159,  148,334,  375,151,  334,317,  251,626,  499,474,  750,326,  534,814,  1079,818
+anchors_p7_1280:
+  - [19,22,  54,36,  32,77]  # P3/8
+  - [70,83,  138,71,  75,173]  # P4/16
+  - [165,159,  148,334,  375,151]  # P5/32
+  - [334,317,  251,626,  499,474]  # P6/64
+  - [750,326,  534,814,  1079,818]  # P7/128
+
+# P7-1920:  thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1920, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 29,34,  81,55,  47,115,  105,124,  207,107,  113,259,  247,238,  222,500,  563,227,  501,476,  376,939,  749,711,  1126,489,  801,1222,  1618,1227
+anchors_p7_1920:
+  - [29,34,  81,55,  47,115]  # P3/8
+  - [105,124,  207,107,  113,259]  # P4/16
+  - [247,238,  222,500,  563,227]  # P5/32
+  - [501,476,  376,939,  749,711]  # P6/64
+  - [1126,489,  801,1222,  1618,1227]  # P7/128
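These are plain YAML lists with three (w, h) anchor pairs per detection layer. A sketch of loading one named set (the path is an assumption):

    import yaml

    with open('models/hub/anchors.yaml', errors='ignore') as f:
        a = yaml.safe_load(f)
    print(a['anchors_p5_640'])  # three layers, each a flat [w1,h1, w2,h2, w3,h3] list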

+ 51 - 0
code/yolov5/models/hub/yolov3-spp.yaml

@@ -0,0 +1,51 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# darknet53 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [32, 3, 1]],  # 0
+   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
+   [-1, 1, Bottleneck, [64]],
+   [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4
+   [-1, 2, Bottleneck, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 5-P3/8
+   [-1, 8, Bottleneck, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 7-P4/16
+   [-1, 8, Bottleneck, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 9-P5/32
+   [-1, 4, Bottleneck, [1024]],  # 10
+  ]
+
+# YOLOv3-SPP head
+head:
+  [[-1, 1, Bottleneck, [1024, False]],
+   [-1, 1, SPP, [512, [5, 9, 13]]],
+   [-1, 1, Conv, [1024, 3, 1]],
+   [-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, Conv, [1024, 3, 1]],  # 15 (P5/32-large)
+
+   [-2, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 8], 1, Concat, [1]],  # cat backbone P4
+   [-1, 1, Bottleneck, [512, False]],
+   [-1, 1, Bottleneck, [512, False]],
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, Conv, [512, 3, 1]],  # 22 (P4/16-medium)
+
+   [-2, 1, Conv, [128, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P3
+   [-1, 1, Bottleneck, [256, False]],
+   [-1, 2, Bottleneck, [256, False]],  # 27 (P3/8-small)
+
+   [[27, 22, 15], 1, Detect, [nc, anchors]],   # Detect(P3, P4, P5)
+  ]
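Each backbone/head row above is [from, number, module, args]; parse_model in models/yolo.py does the real construction. A sketch of how one row reads:

    row = [-1, 1, 'Conv', [32, 3, 1]]  # the first darknet53 row above
    f, n, m, args = row
    # f=-1: input is the previous layer's output; n=1: repeat count, scaled
    # by depth_multiple; args [32, 3, 1] become Conv(c_in, 32, k=3, s=1)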

+ 41 - 0
code/yolov5/models/hub/yolov3-tiny.yaml

@@ -0,0 +1,41 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+anchors:
+  - [10,14, 23,27, 37,58]  # P4/16
+  - [81,82, 135,169, 344,319]  # P5/32
+
+# YOLOv3-tiny backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [16, 3, 1]],  # 0
+   [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 1-P1/2
+   [-1, 1, Conv, [32, 3, 1]],
+   [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 3-P2/4
+   [-1, 1, Conv, [64, 3, 1]],
+   [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 5-P3/8
+   [-1, 1, Conv, [128, 3, 1]],
+   [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 7-P4/16
+   [-1, 1, Conv, [256, 3, 1]],
+   [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 9-P5/32
+   [-1, 1, Conv, [512, 3, 1]],
+   [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]],  # 11
+   [-1, 1, nn.MaxPool2d, [2, 1, 0]],  # 12
+  ]
+
+# YOLOv3-tiny head
+head:
+  [[-1, 1, Conv, [1024, 3, 1]],
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, Conv, [512, 3, 1]],  # 15 (P5/32-large)
+
+   [-2, 1, Conv, [128, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 8], 1, Concat, [1]],  # cat backbone P4
+   [-1, 1, Conv, [256, 3, 1]],  # 19 (P4/16-medium)
+
+   [[19, 15], 1, Detect, [nc, anchors]],  # Detect(P4, P5)
+  ]

+ 51 - 0
code/yolov5/models/hub/yolov3.yaml

@@ -0,0 +1,51 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# darknet53 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [32, 3, 1]],  # 0
+   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
+   [-1, 1, Bottleneck, [64]],
+   [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4
+   [-1, 2, Bottleneck, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 5-P3/8
+   [-1, 8, Bottleneck, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 7-P4/16
+   [-1, 8, Bottleneck, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 9-P5/32
+   [-1, 4, Bottleneck, [1024]],  # 10
+  ]
+
+# YOLOv3 head
+head:
+  [[-1, 1, Bottleneck, [1024, False]],
+   [-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, Conv, [1024, 3, 1]],
+   [-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, Conv, [1024, 3, 1]],  # 15 (P5/32-large)
+
+   [-2, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 8], 1, Concat, [1]],  # cat backbone P4
+   [-1, 1, Bottleneck, [512, False]],
+   [-1, 1, Bottleneck, [512, False]],
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, Conv, [512, 3, 1]],  # 22 (P4/16-medium)
+
+   [-2, 1, Conv, [128, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P3
+   [-1, 1, Bottleneck, [256, False]],
+   [-1, 2, Bottleneck, [256, False]],  # 27 (P3/8-small)
+
+   [[27, 22, 15], 1, Detect, [nc, anchors]],   # Detect(P3, P4, P5)
+  ]

+ 48 - 0
code/yolov5/models/hub/yolov5-bifpn.yaml

@@ -0,0 +1,48 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+
+# Parameters
+nc: 80  # number of classes
+depth_multiple: 1.0  # model depth multiple
+width_multiple: 1.0  # layer channel multiple
+anchors:
+  - [10,13, 16,30, 33,23]  # P3/8
+  - [30,61, 62,45, 59,119]  # P4/16
+  - [116,90, 156,198, 373,326]  # P5/32
+
+# YOLOv5 v6.0 backbone
+backbone:
+  # [from, number, module, args]
+  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
+   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+   [-1, 3, C3, [128]],
+   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+   [-1, 6, C3, [256]],
+   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+   [-1, 9, C3, [512]],
+   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
+   [-1, 3, C3, [1024]],
+   [-1, 1, SPPF, [1024, 5]],  # 9
+  ]
+
+# YOLOv5 v6.0 BiFPN head
+head:
+  [[-1, 1, Conv, [512, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+   [-1, 3, C3, [512, False]],  # 13
+
+   [-1, 1, Conv, [256, 1, 1]],
+   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
+
+   [-1, 1, Conv, [256, 3, 2]],
+   [[-1, 14, 6], 1, Concat, [1]],  # cat P4 <--- BiFPN change
+   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
+
+   [-1, 1, Conv, [512, 3, 2]],
+   [[-1, 10], 1, Concat, [1]],  # cat head P5
+   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
+
+   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+  ]

Some files were not shown because too many files changed in this diff