|
@@ -1,4 +1,4 @@
|
|
|
-├── deep-hough-transform-master #深度霍夫直线检测算法 https://arxiv.org/abs/2003.04676
|
|
|
+├── deep-hough-transform-master #深度霍夫直线检测训练算法 https://arxiv.org/abs/2003.04676
|
|
|
│ ├── basic_ops.py
|
|
|
│ ├── chamfer_distance
|
|
|
│ │ ├── chamfer_distance.cpp
|
|
@@ -65,7 +65,7 @@
|
|
|
│ └── utils.py
|
|
|
├── model #训练后的模型存放路径
|
|
|
│ ├── cb_dht #
|
|
|
-│ │ ├── dht_r50_fpn_sel-c9a29d40.pth
|
|
|
+│ │ ├── dht_r50_fpn_sel_c9a29d40.pth
|
|
|
│ │ └── dht_r50_nkl_d97b97138.pth
|
|
|
│ ├── cb_mask #皮带mask检测模型
|
|
|
│ │ └── best_dice_loss.pth
|
|
@@ -111,7 +111,7 @@
|
|
|
│ │ └── test.mp4
|
|
|
|
|
|
----
|
|
|
-####<span id="basic_ops.py">basic_ops.py</span>
|
|
|
+####<span id="basic_ops.py">basic_ops.py</span>
|
|
|
class Line(object) #直线信息类
|
|
|
class LineAnnotation(object) #直线注释类(没有引用该类)
|
|
|
def line2mask(size, lines) #在mask图像中画出直线
|
|
@@ -120,7 +120,7 @@ def int2arc(k, num_directions) #int转弧度
|
|
|
def arc2int(theta, num_directions) #弧度转int(没有引用该方法)
|
|
|
|
|
|
----
|
|
|
-####<span id="utils.py">utils.py</span>
|
|
|
+####<span id="utils.py">utils.py</span>
|
|
|
def draw_line(y, x, angle, image, color=(0, 0, 255), num_directions=24) #根据x、y坐标和角度画线(没有引用该方法)
|
|
|
def convert_line_to_hough(line, size=(32, 32)) #line2hough方法的附属方法
|
|
|
def line2hough(line, numAngle, numRho, size=(32, 32)) #将线转换为霍夫表示方式(类似于极坐标参数表示方法)(没有引用该方法)
|
|
@@ -136,24 +136,24 @@ def overflow(x, size=400) #图像大小判断
|
|
|
def edge_align(coords, filename, size, division=9) #图像边缘对齐(没有引用该方法)
|
|
|
|
|
|
----
|
|
|
-####<span id="network.py">network.py</span>
|
|
|
+####<span id="network.py">network.py</span>
|
|
|
class Net(nn.Module) #模型选择(resnet18、resnet50、resnet101、resnext50、vgg16、mobilenetv2、res2net50、mobilenetv2)
|
|
|
|
|
|
----
|
|
|
-####<span id="edge_utils.py">edge_utils.py</span>
|
|
|
+####<span id="edge_utils.py">edge_utils.py</span>
|
|
|
def predict_single_image(model, image, size) #预测单张图像
|
|
|
|
|
|
----
|
|
|
-####<span id="edge_detector.py">edge_detector.py</span>
|
|
|
+####<span id="edge_detector.py">edge_detector.py</span>
|
|
|
class EdgeDetector(object) #边缘检测类
|
|
|
|
|
|
----
|
|
|
-####<span id="dht.py">dht.py</span>
|
|
|
+####<span id="dht.py">dht.py</span>
|
|
|
class DHT_Layer(nn.Module) # 深度霍夫变换网络结构
|
|
|
class DHT(nn.Module) 深度霍夫变换的引用
|
|
|
|
|
|
----
|
|
|
-####<span id="geometry_utils.py">geometry_utils.py</span>
|
|
|
+####<span id="geometry_utils.py">geometry_utils.py</span>
|
|
|
def _calc_abc_from_line_2d(point1, point2) #_get_line_cross_point函数的附属函数
|
|
|
def _get_line_cross_point(line1, line2) #通过两条直线的点坐标集合直接求出两条直线的交点坐标(不需要求出直线方程。line1:[point1,point2])
|
|
|
def kb_get_point(line, height, width) #通过直线方程和图片大小确定该直线与图像边缘相交的点坐标
|
|
@@ -166,28 +166,28 @@ def lines_up_or_up_down(line_1, line_2, width, height) #判断上下边缘线
|
|
|
def _list_remove(lists, a) #列表移除元素
|
|
|
|
|
|
----
|
|
|
-####<span id="belt.py">belt.py</span>
|
|
|
+####<span id="belt.py">belt.py</span>
|
|
|
class CbMaskDetection(object) #皮带mask检测类
|
|
|
|
|
|
----
|
|
|
-####<span id="run.py">run.py</span>
|
|
|
+####<span id="run.py">run.py</span>
|
|
|
class Detector(object) #画图、检测集合类
|
|
|
def predict_image(detector, image_path) #图像检测
|
|
|
def predict_video(video_path, detector) #视频检测
|
|
|
def main() #主函数
|
|
|
|
|
|
----
|
|
|
-####<span id="roller.py">roller.py</span>
|
|
|
+####<span id="roller.py">roller.py</span>
|
|
|
class RollerDetection(object) #托辊检测类
|
|
|
|
|
|
----
|
|
|
-####<span id="build_contextpath.py">build_contextpath.py</span>
|
|
|
+####<span id="build_contextpath.py">build_contextpath.py</span>
|
|
|
class resnet18(torch.nn.Module) #resnet18模型
|
|
|
class resnet101(torch.nn.Module) #resnet101模型
|
|
|
def build_contextpath(name) #模型选择
|
|
|
|
|
|
----
|
|
|
-####<span id="mask_component_utils.py">mask_component_utils.py</span>
|
|
|
+####<span id="mask_component_utils.py">mask_component_utils.py</span>
|
|
|
class RandomCrop(object) #在随机位置上裁剪给定的PIL图像(没有引用该类)
|
|
|
class OHEM_CrossEntroy_Loss(nn.Module) #交叉熵损失(没有引用该类)
|
|
|
def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1,max_iter=300, power=0.9) #学习率的多项衰减(没有引用该方法)
|
|
@@ -204,7 +204,7 @@ def cal_miou(miou_list, csv_path) #(没有引用该方法)
|
|
|
def group_weight(weight_group, module, norm_layer, lr) #(没有引用该方法)
|
|
|
|
|
|
----
|
|
|
-####<span id="build_BiSeNet.py">build_BiSeNet.py</span>
|
|
|
+####<span id="build_BiSeNet.py">build_BiSeNet.py</span>
|
|
|
class ConvBlock(torch.nn.Module) #卷积块
|
|
|
class Spatial_path(torch.nn.Module) #Spatial Path部分,用于保存空间信息 https://zhuanlan.zhihu.com/p/47250633
|
|
|
class AttentionRefinementModule(torch.nn.Module) #Context Path部分,U形结构
|
|
@@ -212,7 +212,7 @@ class FeatureFusionModule(torch.nn.Module) #特征融合模块
|
|
|
class BiSeNet(torch.nn.Module) #BiSeNet模型
|
|
|
|
|
|
----
|
|
|
-####<span id="bbox_common_utils.py">bbox_common_utils.py</span>
|
|
|
+####<span id="bbox_common_utils.py">bbox_common_utils.py</span>
|
|
|
def cal_dis_point_line(point, line_point) #求托辊关键点到皮带边缘线的距离
|
|
|
def fitting_straight_line_function(x, k, b) #直线方程表达式
|
|
|
def ignore_center(self, height, width, x_min, y_min, x_max, y_max) # 检查托辊框是否越界
|
|
@@ -222,7 +222,7 @@ def get_roller_line(self, point_side_label, coordinate_list, top_down_label_dict
|
|
|
def get_roller_info(self, top_down_label_dict, coordinate_list, width, height) 主方法
|
|
|
|
|
|
---
|
|
|
-###运行程序
|
|
|
+###运行程序
|
|
|
根据 https://github.com/Hanqer/deep-hough-transform 教程安装deep-hough
|
|
|
运行 run.py
|
|
|
|