Browse Source

版本3 box和line绘图完成

xue50 3 months ago
parent
commit
3b887ecb72

+ 1 - 0
.gitignore

@@ -1,5 +1,6 @@
 .idea
 *.pt
+*.pth
 *.log
 *.onnx
 runs

+ 178 - 3
models/line_detect/postprocess.py

@@ -68,6 +68,7 @@ def box_line_(pred):
     return pred
 
 
+# box与line匹配后画在一张图上,不设置阈值,直接画
 def show_(imgs, pred, epoch, writer):
     col = [
         '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
@@ -120,6 +121,7 @@ def show_(imgs, pred, epoch, writer):
     writer.add_image("all", img2, epoch)
 
 
+# box与line匹配后画在一张图上,设置阈值
 def show_predict(imgs, pred, t_start):
     col = [
         '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
@@ -144,7 +146,7 @@ def show_predict(imgs, pred, t_start):
         '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
     ]
     print(len(col))
-    im = imgs.permute(1, 2, 0)
+    im = imgs.permute(1, 2, 0)  # 处理为 [512, 512, 3]
     boxes = pred[0]['boxes'].cpu().numpy()
     box_scores = pred[0]['scores'].cpu().numpy()
     lines = pred[0]['line'].cpu().numpy()
@@ -161,11 +163,184 @@ def show_predict(imgs, pred, t_start):
         if box_score > 0.7 and line_score > 0.9:
             ax.add_patch(
                 plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
-            ax.scatter(a[1], a[0], c='#871F78', s=2)
-            ax.scatter(b[1], b[0], c='#871F78', s=2)
+            ax.scatter(a[1], a[0], c='#871F78', s=10)
+            ax.scatter(b[1], b[0], c='#871F78', s=10)
             ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
             idx = idx + 1
     t_end = time.time()
     print(f'predict used:{t_end - t_start}')
 
     plt.show()
+
+
+# 下面的都没有进行box与line的一一匹配
+# 只画线,设阈值
+def show_line(imgs, pred, t_start):
+    """Plot every predicted line segment whose score is not below 0.9.
+
+    Args:
+        imgs: image tensor; it is permuted with (1, 2, 0) before display.
+        pred: model output; ``pred[-1]['wires']`` supplies ``'lines'`` and
+            ``'score'`` (only entry 0 of each is drawn).
+        t_start: reference timestamp used for the elapsed-time printout.
+    """
+    wires = pred[-1]['wires']
+    # Coordinates are rescaled by 512 / 128 before drawing
+    # (presumably feature-map -> image resolution; confirm with the model).
+    segments = wires['lines'][0].cpu().numpy() / 128 * 512
+    seg_scores = wires['score'].cpu().numpy()[0]
+
+    # Visualise the prediction on top of the input image.
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(imgs.permute(1, 2, 0)))
+
+    for k, (start, end) in enumerate(segments):
+        if not seg_scores[k] < 0.9:
+            # Index 1 of a point is plotted as x, index 0 as y.
+            xs = [start[1], end[1]]
+            ys = [start[0], end[0]]
+            ax.scatter(xs[0], ys[0], c='#871F78', s=2)
+            ax.scatter(xs[1], ys[1], c='#871F78', s=2)
+            ax.plot(xs, ys, c='red', linewidth=1)
+
+    print(f'show_line used:{time.time() - t_start}')
+
+    plt.show()
+
+
+# 只画框,设阈值
+def show_box(imgs, pred, t_start):
+    """Draw every predicted box whose score is not below 0.7 over the image.
+
+    Args:
+        imgs: image tensor; it is permuted with (1, 2, 0) before display.
+        pred: model output; ``pred[0]`` supplies ``'boxes'`` (unpacked as
+            x0, y0, x1, y1) and ``'scores'``.
+        t_start: reference timestamp used for the elapsed-time printout.
+    """
+    col = [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+    im = imgs.permute(1, 2, 0)
+    boxes = pred[0]['boxes'].cpu().numpy()
+    box_scores = pred[0]['scores'].cpu().numpy()
+
+    # Visualise the prediction on top of the input image.
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(im))
+
+    for idx, box in enumerate(boxes):
+        if box_scores[idx] < 0.7:
+            continue
+        x0, y0, x1, y1 = box
+        # Wrap around the palette: idx counts all boxes (including skipped
+        # ones), so it can exceed the number of colours.
+        ax.add_patch(
+            plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
+                          edgecolor=col[idx % len(col)], linewidth=1))
+
+    t_end = time.time()
+    print(f'show_box used:{t_end - t_start}')
+
+    plt.show()
+
+
+# 将show_line与show_box合并,传入参数确定显示框还是线  都不显示,输出原图
+def show_box_or_line(imgs, pred, show_line=False, show_box=False):
+    """Overlay boxes and/or line segments on a single plot of the image.
+
+    No score threshold is applied. With both flags False only the image is
+    shown.
+
+    Args:
+        imgs: image tensor; it is permuted with (1, 2, 0) before display.
+        pred: model output; ``pred[0]['boxes']`` is read when ``show_box`` is
+            set and ``pred[-1]['wires']['lines'][0]`` when ``show_line`` is set.
+        show_line: draw the line segments. (Inside this function the name
+            shadows the module-level ``show_line`` helper.)
+        show_box: draw the boxes. (Shadows the module-level ``show_box``.)
+    """
+    col = [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+    im = imgs.permute(1, 2, 0)
+
+    # Visualise the prediction on top of the input image.
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(im))
+
+    if show_box:
+        # Read the boxes only when they are drawn.
+        boxes = pred[0]['boxes'].cpu().numpy()
+        for idx, box in enumerate(boxes):
+            x0, y0, x1, y1 = box
+            # Wrap around the palette so more boxes than colours cannot
+            # raise IndexError.
+            ax.add_patch(
+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
+                              edgecolor=col[idx % len(col)], linewidth=1))
+
+    if show_line:
+        # Coordinates are rescaled by 512 / 128 before drawing.
+        line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+        for a, b in line:
+            # Index 1 of a point is plotted as x, index 0 as y.
+            ax.scatter(a[1], a[0], c='#871F78', s=2)
+            ax.scatter(b[1], b[0], c='#871F78', s=2)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+
+    plt.show()
+
+
+# 将show_line与show_box合并,传入参数确定显示框还是线  一起画
+def show_box_and_line(imgs, pred, show_line=False, show_box=False):
+    """Show boxes and line segments side by side in a 1x2 figure.
+
+    The left axes receives the image with boxes when ``show_box`` is set, the
+    right axes the image with lines when ``show_line`` is set; an axes whose
+    flag is False is left empty. No score threshold is applied.
+
+    Args:
+        imgs: image tensor; it is permuted with (1, 2, 0) before display.
+        pred: model output; ``pred[0]['boxes']`` is read when ``show_box`` is
+            set and ``pred[-1]['wires']['lines'][0]`` when ``show_line`` is set.
+        show_line: draw the line panel. (Inside this function the name
+            shadows the module-level ``show_line`` helper.)
+        show_box: draw the box panel. (Shadows the module-level ``show_box``.)
+    """
+    col = [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+    im = imgs.permute(1, 2, 0)
+
+    # Visualise the prediction: one panel for boxes, one for lines.
+    fig, axs = plt.subplots(1, 2, figsize=(10, 10))
+
+    if show_box:
+        # Read the boxes only when they are drawn.
+        boxes = pred[0]['boxes'].cpu().numpy()
+        axs[0].imshow(np.array(im))
+        for idx, box in enumerate(boxes):
+            x0, y0, x1, y1 = box
+            # Wrap around the palette so more boxes than colours cannot
+            # raise IndexError.
+            axs[0].add_patch(
+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
+                              edgecolor=col[idx % len(col)], linewidth=1))
+        axs[0].set_title('Boxes')
+
+    if show_line:
+        # Coordinates are rescaled by 512 / 128 before drawing.
+        line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+        axs[1].imshow(np.array(im))
+        for a, b in line:
+            # Index 1 of a point is plotted as x, index 0 as y.
+            axs[1].scatter(a[1], a[0], c='#871F78', s=2)
+            axs[1].scatter(b[1], b[0], c='#871F78', s=2)
+            axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+        axs[1].set_title('Lines')
+
+    # Adjust the spacing so titles and tick labels do not overlap.
+    plt.tight_layout()
+    plt.show()

+ 10 - 8
models/line_detect/predict.py

@@ -1,5 +1,6 @@
 import os
 
+import skimage
 import torch
 from PIL import Image
 import matplotlib.pyplot as plt
@@ -57,11 +58,6 @@ def show_line(img, pred):
     boxes = pred[0]['boxes'].cpu().numpy()
     boxes_scores = pred[0]['scores'].cpu().numpy()
 
-    # for box in boxes:
-    #     x0, y0, x1, y1 = box
-    #     rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1)
-    #     ax.add_patch(rect)  # 将矩形添加到 Axes 对象上
-
     for b, s in zip(boxes, boxes_scores):
         # print(f'box:{b}, box_score:{s}')
         if s < 0.7:
@@ -115,16 +111,22 @@ def predict(pt_path, model, img):
     transform = transforms.ToTensor()
     img_tensor = transform(img)
 
+    # 将图像调整为512x512大小
+    im = img_tensor.permute(1, 2, 0)  # [512, 512, 3]
+    im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
+    img_tensor = torch.tensor(im_resized).permute(2, 0, 1)
+
     with torch.no_grad():
         predictions = model([img_tensor])
-        print(predictions[0])
+        # print(predictions[0])
 
     show_line(img_tensor, predictions)
 
 
 if __name__ == '__main__':
     model = linenet_resnet50_fpn()
-    pt_path = r'D:\python\PycharmProjects\lcnn-master\lcnn_\20250212\MultiVisionModels\models\line_detect\linenet_wts\resnet50_best_e8.pth'
+    pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
     # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png'  # 工件图
-    img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe图
+    # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe图
+    img_path = r'C:\Users\m2337\Desktop\49.jpg'
     predict(pt_path, model, img_path)

+ 84 - 8
models/line_detect/predict2.py

@@ -1,6 +1,8 @@
 import time
 
-from models.line_detect.postprocess import show_predict
+import skimage
+
+from models.line_detect.postprocess import show_predict, show_line, show_box, show_box_or_line, show_box_and_line
 import os
 
 import torch
@@ -13,6 +15,7 @@ from torchvision import transforms
 
 # from models.wirenet.postprocess import postprocess
 from models.wirenet.postprocess import postprocess
+from rtree import index
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
@@ -30,6 +33,7 @@ def load_best_model(model, save_path, device):
         print(f"No saved model found at {save_path}")
     return model
 
+
 def box_line_(pred):
     for idx, box_ in enumerate(pred[0:-1]):
         box = box_['boxes']  # 是一个tensor
@@ -61,6 +65,58 @@ def box_line_(pred):
     return pred
 
 
+def box_line_optimized(pred):
+    """Attach to every predicted box the best-scoring line lying inside it.
+
+    An R-tree over the line bounding boxes pre-selects candidates for each
+    box; a candidate is accepted only when both of its end points are inside
+    the box. Boxes without any enclosed line get an all-zero line and a score
+    of 0.0.
+
+    Args:
+        pred: model output. ``pred[:-1]`` are per-image dicts with a
+            ``'boxes'`` tensor of (x0, y0, x1, y1) rows; ``pred[-1]['wires']``
+            holds ``'lines'`` and ``'score'``, indexed by image.
+
+    Returns:
+        The same ``pred`` list, with ``'line'`` and ``'line_score'`` tensors
+        added to each per-image dict.
+    """
+    wires = pred[-1]['wires']
+
+    for idx_box, box_ in enumerate(pred[0:-1]):
+        boxes = box_['boxes'].cpu().numpy()
+        # Use the lines/scores of *this* image (as box_line_ does) and convert
+        # them once instead of once per candidate line.
+        lines = wires['lines'][idx_box].cpu().numpy() / 128 * 512
+        scores = wires['score'][idx_box].cpu().numpy()
+
+        # Line points are treated as (y, x), matching the containment test
+        # below, while boxes are (x0, y0, x1, y1). The R-tree must therefore
+        # be filled in (x_min, y_min, x_max, y_max) order; exact bounds need
+        # no padding.
+        tree = index.Index()
+        for idx_line, (p0, p1) in enumerate(lines):
+            tree.insert(idx_line, (
+                float(min(p0[1], p1[1])), float(min(p0[0], p1[0])),
+                float(max(p0[1], p1[1])), float(max(p0[0], p1[0])),
+            ))
+
+        line_ = []
+        score_ = []
+
+        for x0, y0, x1, y1 in boxes:
+            score_max = 0.0
+            tmp = [[0.0, 0.0], [0.0, 0.0]]
+
+            # All lines whose bounding box touches this box; sorted so that
+            # ties are resolved by the lowest line index, like box_line_.
+            possible_matches = sorted(
+                tree.intersection((float(x0), float(y0), float(x1), float(y1))))
+
+            for j in possible_matches:
+                (ya, xa), (yb, xb) = lines[j]
+                inside = (x0 <= xa <= x1 and x0 <= xb <= x1 and
+                          y0 <= ya <= y1 and y0 <= yb <= y1)
+                if inside and scores[j] > score_max:
+                    tmp = lines[j].tolist()
+                    score_max = float(scores[j])
+
+            line_.append(tmp)
+            score_.append(score_max)
+
+        pred[idx_box]['line'] = torch.tensor(line_)
+        pred[idx_box]['line_score'] = torch.tensor(score_)
+
+    return pred
+
+
 def predict(pt_path, model, img):
     model = load_best_model(model, pt_path, device)
 
@@ -70,24 +126,44 @@ def predict(pt_path, model, img):
         img = Image.open(img).convert("RGB")
 
     transform = transforms.ToTensor()
-    img_tensor = transform(img)
+    img_tensor = transform(img)  # [3, 512, 512]
+
+    # img_ = img_tensor
+
+    # 将图像调整为512x512大小
+    t_start = time.time()
+    im = img_tensor.permute(1, 2, 0)  # [512, 512, 3]
+    im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
+    img_ = torch.tensor(im_resized).permute(2, 0, 1)
+    t_end = time.time()
+    print(f'switch img used:{t_end - t_start}')
 
     with torch.no_grad():
-        predictions = model([img_tensor.to(device)])
+        predictions = model([img_.to(device)])
         # print(predictions)
 
-    pred = box_line_(predictions)
-    # print(f'pred:{pred[0]}')
-    show_predict(img_tensor, pred, t_start)
+    show_line(img_, predictions, t_start)   # 只画线
+    # show_box(img_, predictions, t_start)   # 只画kuang
+    # show_box_or_line(img_, predictions, show_line=True, show_box=True)   # 参数确定画什么
+    # show_box_and_line(img_, predictions, show_line=True, show_box=True)  # 一起画 1x2 2张图
+
+    # t_start = time.time()
+    # # pred = box_line_optimized(predictions)
+    # pred = box_line_(predictions)
+    # t_end = time.time()
+    # print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
+
+    # show_predict(img_, pred, t_start)
 
 
 if __name__ == '__main__':
     t_start = time.time()
     print(f'start to predict:{t_start}')
     model = linenet_resnet50_fpn().to(device)
-    pt_path = r'D:\python\PycharmProjects\lcnn-master\lcnn_\20250212\MultiVisionModels\models\line_detect\linenet_wts\resnet50_best_e8.pth'
-    img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png'  # 工件图
+    pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
+    # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png'  # 工件图
     # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe图
+    img_path = r'C:\Users\m2337\Desktop\9.jpg'
     predict(pt_path, model, img_path)
     t_end = time.time()
     print(f'predict used:{t_end - t_start}')

+ 221 - 0
models/line_detect/predict3.py

@@ -0,0 +1,221 @@
+# 并行计算
+
+import time
+
+import skimage
+
+from models.line_detect.postprocess import show_predict
+import os
+
+import torch
+from PIL import Image
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+import numpy as np
+from models.line_detect.line_net import linenet_resnet50_fpn
+from torchvision import transforms
+
+# from models.wirenet.postprocess import postprocess
+from models.wirenet.postprocess import postprocess
+from rtree import index
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+import multiprocessing as mp
+
+
+def _build_line_index(lines_np):
+    """Return an R-tree holding the bounding box of every line segment.
+
+    Line points are treated as (y, x), matching ``box_line_``; the tree is
+    filled in (x_min, y_min, x_max, y_max) order so that it can be queried
+    directly with (x0, y0, x1, y1) boxes.
+    """
+    tree = index.Index()
+    for idx_line, (p0, p1) in enumerate(lines_np):
+        tree.insert(idx_line, (
+            float(min(p0[1], p1[1])), float(min(p0[0], p1[0])),
+            float(max(p0[1], p1[1])), float(max(p0[0], p1[0])),
+        ))
+    return tree
+
+
+def process_box(box, lines, scores, idx):
+    """Match the boxes of one image with their best-scoring enclosed line.
+
+    Args:
+        box: array of (x0, y0, x1, y1) rows for one image.
+        lines: tensor whose entry 0 holds the line segments of that image.
+        scores: per-line scores (tensor) for that image.
+        idx: R-tree used to pre-select candidate lines, or ``None`` to build
+            one here from ``lines``.
+
+    Returns:
+        ``(lines, scores)`` tensors with one entry per box; boxes without an
+        enclosed line get an all-zero line and a score of 0.0.
+    """
+    line_ = []
+    score_ = []
+
+    # Convert once instead of once per candidate line.
+    lines_np = lines[0].cpu().numpy() / 128 * 512
+    scores_np = scores.cpu().numpy()
+    if idx is None:
+        idx = _build_line_index(lines_np)
+
+    for i in box:
+        score_max = 0.0
+        tmp = [[0.0, 0.0], [0.0, 0.0]]
+
+        # All lines that may intersect the current box; sorted so that ties
+        # are resolved by the lowest line index, like box_line_.
+        possible_matches = sorted(idx.intersection(
+            (float(i[0]), float(i[1]), float(i[2]), float(i[3]))))
+
+        for j in possible_matches:
+            line_j = lines_np[j]
+            # Index 1 of a point is compared with the box x range and index 0
+            # with the y range, consistent with box_line_ and
+            # box_line_optimized.
+            if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
+                    line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
+                    line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
+                    line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
+
+                if scores_np[j] > score_max:
+                    tmp = line_j.tolist()
+                    score_max = float(scores_np[j])
+
+        line_.append(tmp)
+        score_.append(score_max)
+
+    return torch.tensor(line_), torch.tensor(score_)
+
+
+def box_line_optimized1(pred):
+    """Run ``process_box`` for every image, in parallel when there are several.
+
+    Args:
+        pred: model output. ``pred[:-1]`` are per-image dicts with a
+            ``'boxes'`` tensor; ``pred[-1]['wires']`` holds ``'lines'`` and
+            ``'score'``, indexed by image.
+
+    Returns:
+        The same ``pred`` list, with ``'line'`` and ``'line_score'`` added to
+        each per-image dict.
+    """
+    wires = pred[-1]['wires']
+    # Move to the CPU before handing data to worker processes.
+    lines = wires['lines'].cpu()
+    scores = wires['score'].cpu()
+
+    # One task per image, each with its own slice of lines and scores.
+    # The R-tree is built inside process_box (idx=None).
+    # NOTE(review): an in-memory rtree Index is reported to lose its contents
+    # when pickled to another process - confirm for the installed version.
+    data_to_process = [
+        (box_['boxes'].cpu().numpy(), lines[n:n + 1], scores[n], None)
+        for n, box_ in enumerate(pred[0:-1])
+    ]
+
+    if len(data_to_process) <= 1:
+        # A process pool only adds start-up cost for a single image.
+        results = [process_box(*task) for task in data_to_process]
+    else:
+        workers = min(mp.cpu_count(), len(data_to_process))
+        with mp.Pool(processes=workers) as pool:
+            results = pool.starmap(process_box, data_to_process)
+
+    # Write the results back into the original pred.
+    for idx_box, (processed_list, processed_s_list) in enumerate(results):
+        pred[idx_box]['line'] = processed_list
+        pred[idx_box]['line_score'] = processed_s_list
+
+    return pred
+
+
+def load_best_model(model, save_path, device):
+    """Load checkpoint weights into ``model`` when ``save_path`` exists.
+
+    The checkpoint is read with ``torch.load`` and must provide the keys
+    ``'model_state_dict'``, ``'epoch'`` and ``'loss'``. When the file is
+    missing a message is printed and the model is returned untouched.
+
+    Args:
+        model: object exposing ``load_state_dict``.
+        save_path: path of the checkpoint file.
+        device: passed to ``torch.load`` as ``map_location``.
+
+    Returns:
+        The ``model`` object that was passed in.
+    """
+    if not os.path.exists(save_path):
+        print(f"No saved model found at {save_path}")
+        return model
+
+    state = torch.load(save_path, map_location=device)
+    model.load_state_dict(state['model_state_dict'])
+    epoch, loss = state['epoch'], state['loss']
+    print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+    return model
+
+
+def box_line_(pred):
+    """Attach to every predicted box the best-scoring line lying inside it.
+
+    For each per-image dict in ``pred[:-1]`` every line of that image is
+    tested against every box (brute force); the line with the highest score
+    whose two end points are inside the box is kept. Boxes without an enclosed
+    line keep the all-zero placeholder and a score of 0.0.
+
+    Returns the same ``pred`` list with ``'line'`` and ``'line_score'``
+    tensors added to each per-image dict.
+    """
+    for idx, box_ in enumerate(pred[0:-1]):
+        box = box_['boxes']  # a tensor
+        # Lines of image ``idx``, rescaled by 512 / 128.
+        line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
+        score = pred[-1]['wires']['score'][idx]
+        line_ = []
+        score_ = []
+
+        for i in box:
+            # Placeholders used when no line lies inside the box.
+            score_max = 0.0
+            tmp = [[0.0, 0.0], [0.0, 0.0]]
+
+            for j in range(len(line)):
+                # Index 1 of each end point is checked against i[0]..i[2] and
+                # index 0 against i[1]..i[3]; both end points must be inside.
+                if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
+                        line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
+                        line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
+                        line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+
+                    # Strict ">" keeps the first line among equal scores.
+                    if score[j] > score_max:
+                        tmp = line[j]
+                        score_max = score[j]
+            line_.append(tmp)
+            score_.append(score_max)
+        processed_list = torch.tensor(line_)
+        pred[idx]['line'] = processed_list
+
+        processed_s_list = torch.tensor(score_)
+        pred[idx]['line_score'] = processed_s_list
+    return pred
+
+
+def box_line_optimized(pred):
+    """Attach to every predicted box the best-scoring line lying inside it.
+
+    An R-tree over the line bounding boxes pre-selects candidates for each
+    box; a candidate is accepted only when both of its end points are inside
+    the box. Boxes without any enclosed line get an all-zero line and a score
+    of 0.0.
+
+    Args:
+        pred: model output. ``pred[:-1]`` are per-image dicts with a
+            ``'boxes'`` tensor of (x0, y0, x1, y1) rows; ``pred[-1]['wires']``
+            holds ``'lines'`` and ``'score'``, indexed by image.
+
+    Returns:
+        The same ``pred`` list, with ``'line'`` and ``'line_score'`` tensors
+        added to each per-image dict.
+    """
+    wires = pred[-1]['wires']
+
+    for idx_box, box_ in enumerate(pred[0:-1]):
+        boxes = box_['boxes'].cpu().numpy()
+        # Use the lines/scores of *this* image (as box_line_ does) and convert
+        # them once instead of once per candidate line.
+        lines = wires['lines'][idx_box].cpu().numpy() / 128 * 512
+        scores = wires['score'][idx_box].cpu().numpy()
+
+        # Line points are treated as (y, x), matching the containment test
+        # below, while boxes are (x0, y0, x1, y1). The R-tree must therefore
+        # be filled in (x_min, y_min, x_max, y_max) order; exact bounds need
+        # no padding.
+        tree = index.Index()
+        for idx_line, (p0, p1) in enumerate(lines):
+            tree.insert(idx_line, (
+                float(min(p0[1], p1[1])), float(min(p0[0], p1[0])),
+                float(max(p0[1], p1[1])), float(max(p0[0], p1[0])),
+            ))
+
+        line_ = []
+        score_ = []
+
+        for x0, y0, x1, y1 in boxes:
+            score_max = 0.0
+            tmp = [[0.0, 0.0], [0.0, 0.0]]
+
+            # All lines whose bounding box touches this box; sorted so that
+            # ties are resolved by the lowest line index, like box_line_.
+            possible_matches = sorted(
+                tree.intersection((float(x0), float(y0), float(x1), float(y1))))
+
+            for j in possible_matches:
+                (ya, xa), (yb, xb) = lines[j]
+                inside = (x0 <= xa <= x1 and x0 <= xb <= x1 and
+                          y0 <= ya <= y1 and y0 <= yb <= y1)
+                if inside and scores[j] > score_max:
+                    tmp = lines[j].tolist()
+                    score_max = float(scores[j])
+
+            line_.append(tmp)
+            score_.append(score_max)
+
+        pred[idx_box]['line'] = torch.tensor(line_)
+        pred[idx_box]['line_score'] = torch.tensor(score_)
+
+    return pred
+
+
+def predict(pt_path, model, img):
+    """Run the model on one image, match boxes with lines and plot the result.
+
+    Args:
+        pt_path: checkpoint path handed to ``load_best_model``.
+        model: network that receives the weights and is run in eval mode.
+        img: a file path (opened with PIL and converted to RGB) or an image
+            object accepted by ``transforms.ToTensor``.
+    """
+    # Import the submodule explicitly: a bare ``import skimage`` is not
+    # guaranteed to make ``skimage.transform`` available on every version.
+    from skimage.transform import resize
+
+    # Take the start time locally so the function also works when imported;
+    # it previously read a global ``t_start`` that only exists in script mode.
+    t_start = time.time()
+
+    model = load_best_model(model, pt_path, device)
+
+    model.eval()
+
+    if isinstance(img, str):
+        img = Image.open(img).convert("RGB")
+
+    transform = transforms.ToTensor()
+    img_tensor = transform(img)  # channels-first tensor
+
+    # Resize the image to 512x512 (channels-last for skimage, then back).
+    im = img_tensor.permute(1, 2, 0)
+    im_resized = resize(im.cpu().numpy().astype(np.float32), (512, 512))
+    img_ = torch.tensor(im_resized).permute(2, 0, 1)
+
+    with torch.no_grad():
+        predictions = model([img_.to(device)])
+
+    pred = box_line_optimized1(predictions)
+    show_predict(img_, pred, t_start)
+
+
+# Script entry point: build the model, run one prediction and report the
+# total wall-clock time. The checkpoint and image paths are machine-specific.
+if __name__ == '__main__':
+    t_start = time.time()
+    print(f'start to predict:{t_start}')
+    model = linenet_resnet50_fpn().to(device)
+    pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
+    # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png'  # workpiece image
+    # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe image
+    img_path = r'C:\Users\m2337\Desktop\49.jpg'
+    predict(pt_path, model, img_path)
+    t_end = time.time()
+    print(f'predict used:{t_end - t_start}')

+ 3 - 1
models/line_detect/train.yaml

@@ -1,11 +1,13 @@
 io:
   logdir: logs/
-  datadir: /root/autodl-tmp/wirenet_rgb_gray
+  datadir: D:\python\PycharmProjects\data
 #  datadir: I:\datasets\wirenet_1000
   resume_from:
   num_workers: 8
   tensorboard_port: 6000
   validation_interval: 300
+  batch_size: 4
+  batch_size_eval: 2
 
 optim:
   name: Adam

+ 141 - 22
models/line_detect/trainer.py

@@ -1,4 +1,3 @@
-
 import os
 import time
 from datetime import datetime
@@ -28,7 +27,11 @@ def _loss(losses):
         total_loss += loss
 
     return total_loss
+
+
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
 def move_to_device(data, device):
     if isinstance(data, (list, tuple)):
         return type(data)(move_to_device(item, device) for item in data)
@@ -39,13 +42,14 @@ def move_to_device(data, device):
     else:
         return data  # 对于非张量类型的数据不做任何改变
 
+
 class Trainer(BaseTrainer):
     def __init__(self, model=None,
                  dataset=None,
                  device='cuda',
                  **kwargs):
 
-        super().__init__(model,dataset,device,**kwargs)
+        super().__init__(model, dataset, device, **kwargs)
 
     def move_to_device(self, data, device):
         if isinstance(data, (list, tuple)):
@@ -57,7 +61,7 @@ class Trainer(BaseTrainer):
         else:
             return data  # 对于非张量类型的数据不做任何改变
 
-    def load_best_model(self,model, optimizer, save_path, device):
+    def load_best_model(self, model, optimizer, save_path, device):
         if os.path.exists(save_path):
             checkpoint = torch.load(save_path, map_location=device)
             model.load_state_dict(checkpoint['model_state_dict'])
@@ -84,30 +88,151 @@ class Trainer(BaseTrainer):
         except Exception as e:
             print(f"TensorBoard logging error: {e}")
 
-    def train_cfg(self, model:BaseModel, cfg):
+    def train_cfg(self, model: BaseModel, cfg):
         # cfg = r'./config/wireframe.yaml'
         cfg = read_yaml(cfg)
         print(f'cfg:{cfg}')
         # print(cfg['n_dyn_negl'])
         self.train(model, **cfg)
 
+    # def train(self, model, **kwargs):
+    #     dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
+    #     train_sampler = torch.utils.data.RandomSampler(dataset_train)
+    #     # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    #     train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=64, drop_last=True)
+    #     train_collate_fn = utils.collate_fn_wirepoint
+    #     data_loader_train = torch.utils.data.DataLoader(
+    #         dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
+    #     )
+    #
+    #     dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
+    #     val_sampler = torch.utils.data.RandomSampler(dataset_val)
+    #     # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
+    #     val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=64, drop_last=True)
+    #     val_collate_fn = utils.collate_fn_wirepoint
+    #     data_loader_val = torch.utils.data.DataLoader(
+    #         dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
+    #     )
+    #
+    #     train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
+    #     wts_path = os.path.join(train_result_ptath, 'weights')
+    #     tb_path = os.path.join(train_result_ptath, 'logs')
+    #     writer = SummaryWriter(tb_path)
+    #
+    #     optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['optim']['lr'])
+    #     # writer = SummaryWriter(kwargs['io']['logdir'])
+    #     model.to(device)
+    #
+    #
+    #
+    #     # # 加载权重
+    #     # save_path = 'logs/pth/best_model.pth'
+    #     # model, optimizer = self.load_best_model(model, optimizer, save_path, device)
+    #
+    #     # logdir_with_pth = os.path.join(kwargs['io']['logdir'], 'pth')
+    #     # os.makedirs(logdir_with_pth, exist_ok=True)  # 创建目录(如果不存在)
+    #     last_model_path = os.path.join(wts_path, 'last.pth')
+    #     best_model_path = os.path.join(wts_path, 'best.pth')
+    #     global_step = 0
+    #
+    #     for epoch in range(kwargs['optim']['max_epoch']):
+    #         print(f"epoch:{epoch}")
+    #         total_train_loss = 0.0
+    #
+    #         model.train()
+    #
+    #         for imgs, targets in data_loader_train:
+    #             imgs = move_to_device(imgs, device)
+    #             targets=move_to_device(targets,device)
+    #             # print(f'imgs:{len(imgs)}')
+    #             # print(f'targets:{len(targets)}')
+    #             losses = model(imgs, targets)
+    #             loss = _loss(losses)
+    #             total_train_loss += loss.item()
+    #             optimizer.zero_grad()
+    #             loss.backward()
+    #             optimizer.step()
+    #             self.writer_loss(writer, losses, global_step)
+    #             global_step+=1
+    #
+    #
+    #         avg_train_loss = total_train_loss / len(data_loader_train)
+    #         if epoch == 0:
+    #             best_loss = avg_train_loss;
+    #
+    #         writer.add_scalar('loss/train', avg_train_loss, epoch)
+    #
+    #
+    #         if os.path.exists(f'{wts_path}/last.pt'):
+    #             os.remove(f'{wts_path}/last.pt')
+    #         # torch.save(model.state_dict(), f'{wts_path}/last.pt')
+    #         save_last_model(model,last_model_path,epoch,optimizer)
+    #         best_loss = save_best_model(model,best_model_path,epoch,avg_train_loss,best_loss,optimizer)
+    #
+    #         model.eval()
+    #         with torch.no_grad():
+    #             for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+    #                 t_start = time.time()
+    #                 print(f'start to predict:{t_start}')
+    #                 pred = model(self.move_to_device(imgs, self.device))
+    #                 t_end = time.time()
+    #                 print(f'predict used:{t_end - t_start}')
+    #                 if batch_idx == 0:
+    #                     show_line(imgs[0], pred, epoch, writer)
+    #                 break
+
     def train(self, model, **kwargs):
-        dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
+        default_params = {
+            'io': {
+                'logdir': 'logs /',
+                'datadir': '/ root / autodl - tmp / wirenet_rgb_gray',
+                'num_workers': 8,
+                'tensorboard_port': 6000,
+                'validation_interval': 300,
+                'batch_size': 4,
+                'batch_size_eval': 2,
+            },
+            'optim':{
+                'name': 'Adam',
+                'lr': 4.0e-4,
+                'amsgrad': True,
+                'weight_decay': 1.0e-4,
+                'max_epoch': 90000000,
+                'lr_decay_epoch': 10,
+            },
+        }
+
+        # 更新默认参数
+        for key, value in kwargs.items():
+            if key in default_params:
+                default_params[key] = value
+            else:
+                raise ValueError(f"Unknown argument: {key}")
+
+        # 解析参数
+        dataset_path = default_params['io']['datadir']
+        num_workers = default_params['io']['num_workers']
+        batch_size_train = default_params['io']['batch_size']
+        batch_size_eval = default_params['io']['batch_size_eval']
+        epochs = default_params['optim']['max_epoch']
+        lr = default_params['optim']['lr']
+
+        dataset_train = WirePointDataset(dataset_path=dataset_path, dataset_type='train')
         train_sampler = torch.utils.data.RandomSampler(dataset_train)
         # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=64, drop_last=True)
+        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=batch_size_train, drop_last=True)
         train_collate_fn = utils.collate_fn_wirepoint
         data_loader_train = torch.utils.data.DataLoader(
-            dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
+            dataset_train, batch_sampler=train_batch_sampler, num_workers=num_workers, collate_fn=train_collate_fn
         )
 
-        dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
+        dataset_val = WirePointDataset(dataset_path=dataset_path, dataset_type='val')
         val_sampler = torch.utils.data.RandomSampler(dataset_val)
         # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=64, drop_last=True)
+        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=batch_size_eval, drop_last=True)
         val_collate_fn = utils.collate_fn_wirepoint
         data_loader_val = torch.utils.data.DataLoader(
-            dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
+            dataset_val, batch_sampler=val_batch_sampler, num_workers=num_workers, collate_fn=val_collate_fn
         )
 
         train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
@@ -115,12 +240,10 @@ class Trainer(BaseTrainer):
         tb_path = os.path.join(train_result_ptath, 'logs')
         writer = SummaryWriter(tb_path)
 
-        optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['optim']['lr'])
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
         # writer = SummaryWriter(kwargs['io']['logdir'])
         model.to(device)
 
-
-
         # # 加载权重
         # save_path = 'logs/pth/best_model.pth'
         # model, optimizer = self.load_best_model(model, optimizer, save_path, device)
@@ -131,7 +254,7 @@ class Trainer(BaseTrainer):
         best_model_path = os.path.join(wts_path, 'best.pth')
         global_step = 0
 
-        for epoch in range(kwargs['optim']['max_epoch']):
+        for epoch in range(epochs):
             print(f"epoch:{epoch}")
             total_train_loss = 0.0
 
@@ -139,7 +262,7 @@ class Trainer(BaseTrainer):
 
             for imgs, targets in data_loader_train:
                 imgs = move_to_device(imgs, device)
-                targets=move_to_device(targets,device)
+                targets = move_to_device(targets, device)
                 # print(f'imgs:{len(imgs)}')
                 # print(f'targets:{len(targets)}')
                 losses = model(imgs, targets)
@@ -149,8 +272,7 @@ class Trainer(BaseTrainer):
                 loss.backward()
                 optimizer.step()
                 self.writer_loss(writer, losses, global_step)
-                global_step+=1
-
+                global_step += 1
 
             avg_train_loss = total_train_loss / len(data_loader_train)
             if epoch == 0:
@@ -158,12 +280,11 @@ class Trainer(BaseTrainer):
 
             writer.add_scalar('loss/train', avg_train_loss, epoch)
 
-
             if os.path.exists(f'{wts_path}/last.pt'):
                 os.remove(f'{wts_path}/last.pt')
             # torch.save(model.state_dict(), f'{wts_path}/last.pt')
-            save_last_model(model,last_model_path,epoch,optimizer)
-            best_loss = save_best_model(model,best_model_path,epoch,avg_train_loss,best_loss,optimizer)
+            save_last_model(model, last_model_path, epoch, optimizer)
+            best_loss = save_best_model(model, best_model_path, epoch, avg_train_loss, best_loss, optimizer)
 
             model.eval()
             with torch.no_grad():
@@ -176,5 +297,3 @@ class Trainer(BaseTrainer):
                     if batch_idx == 0:
                         show_line(imgs[0], pred, epoch, writer)
                     break
-
-