
When predict_box contains no line segment, select the two points inside the box that form the longest segment and return that segment

xue50 8 months ago
parent
commit
c2cffb6b1a
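In other words: when none of the predicted line segments falls inside a detection box, the new fallback collects the junction points that lie inside that box and returns the farthest-apart pair as a stand-in segment (see box_line1 in models/line_detect/predict2.py below). A minimal, self-contained sketch of that idea follows; the helper name, the (y0, x0, y1, x1) box convention, and the brute-force pair search are illustrative assumptions, not the committed code.

import numpy as np

def longest_segment_in_box(points, box):
    """Return the farthest-apart pair of points lying inside box = (y0, x0, y1, x1), or None."""
    y0, x0, y1, x1 = box
    inside = [p for p in points if y0 <= p[0] <= y1 and x0 <= p[1] <= x1]
    if len(inside) < 2:  # at least two points are needed to form a segment
        return None
    best, best_dist = None, 0.0
    for i in range(len(inside)):
        for j in range(i + 1, len(inside)):
            dist = np.linalg.norm(np.array(inside[i]) - np.array(inside[j]))
            if dist > best_dist:
                best, best_dist = (inside[i], inside[j]), dist
    return best

# Tiny usage example with made-up junction points
pts = [(10.0, 12.0), (40.0, 45.0), (30.0, 20.0)]
print(longest_segment_in_box(pts, (0.0, 0.0, 50.0, 50.0)))  # ((10.0, 12.0), (40.0, 45.0))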

+ 97 - 75
aaa.py

@@ -1,77 +1,99 @@
-import torch
-from torchvision.utils import draw_bounding_boxes
-from torchvision import transforms
-import matplotlib.pyplot as plt
-import numpy as np
-
-
-def c(score):
-    # Return a color based on the score; for illustration only, adjust as needed
-    return (1, 0, 0) if score > 0.9 else (0, 1, 0)
-
-
-def postprocess(lines, scores, diag_threshold, min_score, remove_overlaps):
-    # Stand-in post-processing function used to filter line segments
-    nlines = []
-    nscores = []
-    for line, score in zip(lines, scores):
-        if score >= min_score:
-            nlines.append(line)
-            nscores.append(score)
-    return np.array(nlines), np.array(nscores)
-
-
-def show_line(img, pred, epoch, writer):
-    im = img.permute(1, 2, 0).cpu().numpy()
-
-    # Draw the bounding boxes
-    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
-                                      colors="yellow", width=1).permute(1, 2, 0).cpu().numpy()
-
-    H = pred[-1]['wires']
-    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
-    scores = H["score"][0].cpu().numpy()
+# import torch
+# from torchvision.utils import draw_bounding_boxes
+# from torchvision import transforms
+# import matplotlib.pyplot as plt
+# import numpy as np
+#
+#
+# def c(score):
+#     # Return a color based on the score; for illustration only, adjust as needed
+#     return (1, 0, 0) if score > 0.9 else (0, 1, 0)
+#
+#
+# def postprocess(lines, scores, diag_threshold, min_score, remove_overlaps):
+#     # Stand-in post-processing function used to filter line segments
+#     nlines = []
+#     nscores = []
+#     for line, score in zip(lines, scores):
+#         if score >= min_score:
+#             nlines.append(line)
+#             nscores.append(score)
+#     return np.array(nlines), np.array(nscores)
+#
+#
+# def show_line(img, pred, epoch, writer):
+#     im = img.permute(1, 2, 0).cpu().numpy()
+#
+#     # Draw the bounding boxes
+#     boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
+#                                       colors="yellow", width=1).permute(1, 2, 0).cpu().numpy()
+#
+#     H = pred[-1]['wires']
+#     lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+#     scores = H["score"][0].cpu().numpy()
+#
+#     print(f"Lines before deduplication: {len(lines)}")
+#
+#     # Remove duplicate line segments
+#     for i in range(1, len(lines)):
+#         if (lines[i] == lines[0]).all():
+#             lines = lines[:i]
+#             scores = scores[:i]
+#             break
+#
+#     print(f"Lines after deduplication: {len(lines)}")
+#
+#     # Post-process the segments to remove overlapping ones
+#     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+#     nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+#
+#     print(f"Lines after postprocessing: {len(nlines)}")
+#
+#     # Create a new figure and draw the segments and bounding boxes
+#     fig, ax = plt.subplots(figsize=(boxed_image.shape[1] / 100, boxed_image.shape[0] / 100))
+#     ax.imshow(boxed_image)
+#     ax.set_axis_off()
+#     plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+#     plt.margins(0, 0)
+#     plt.gca().xaxis.set_major_locator(plt.NullLocator())
+#     plt.gca().yaxis.set_major_locator(plt.NullLocator())
+#
+#     PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+#     for (a, b), s in zip(nlines, nscores):
+#         if s < 0.85:  # adjust this threshold to filter which segments are shown
+#             continue
+#         plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+#         plt.scatter(a[1], a[0], **PLTOPTS)
+#         plt.scatter(b[1], b[0], **PLTOPTS)
+#
+#     plt.tight_layout()
+#     fig.canvas.draw()
+#     image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+#         fig.canvas.get_width_height()[::-1] + (3,))
+#     plt.close()
+#     img2 = transforms.ToTensor()(image_from_plot)
+#
+#     writer.add_image("output_with_boxes_and_lines", img2, epoch)
+#     print("Image with boxes and lines added to TensorBoard.")
 
-    print(f"Lines before deduplication: {len(lines)}")
 
-    # Remove duplicate line segments
-    for i in range(1, len(lines)):
-        if (lines[i] == lines[0]).all():
-            lines = lines[:i]
-            scores = scores[:i]
-            break
-
-    print(f"Lines after deduplication: {len(lines)}")
-
-    # Post-process the segments to remove overlapping ones
-    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
-    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
-
-    print(f"Lines after postprocessing: {len(nlines)}")
-
-    # Create a new figure and draw the segments and bounding boxes
-    fig, ax = plt.subplots(figsize=(boxed_image.shape[1] / 100, boxed_image.shape[0] / 100))
-    ax.imshow(boxed_image)
-    ax.set_axis_off()
-    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
-    plt.margins(0, 0)
-    plt.gca().xaxis.set_major_locator(plt.NullLocator())
-    plt.gca().yaxis.set_major_locator(plt.NullLocator())
-
-    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
-    for (a, b), s in zip(nlines, nscores):
-        if s < 0.85:  # adjust this threshold to filter which segments are shown
-            continue
-        plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
-        plt.scatter(a[1], a[0], **PLTOPTS)
-        plt.scatter(b[1], b[0], **PLTOPTS)
-
-    plt.tight_layout()
-    fig.canvas.draw()
-    image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
-        fig.canvas.get_width_height()[::-1] + (3,))
-    plt.close()
-    img2 = transforms.ToTensor()(image_from_plot)
-
-    writer.add_image("output_with_boxes_and_lines", img2, epoch)
-    print("Image with boxes and lines added to TensorBoard.")
+import numpy as np
+import matplotlib.pyplot as plt
+from scipy.ndimage import gaussian_filter
+import random
+# Assume we have a few keypoint locations
+keypoints = [(0, 0), (70, 80), (90, 30)]
+
+# Create a blank heatmap
+heatmap = np.zeros((100, 100))
+
+# Add the keypoint locations to the heatmap
+for point in keypoints:
+    y, x = point
+    heatmap[y, x] = random.random()
+    # heatmap[y, x] = 1  # assume a confidence of 1
+
+print(heatmap)
+# Smooth the heatmap with a Gaussian filter
+heatmap_smooth = gaussian_filter(heatmap, sigma=1)
+print(heatmap_smooth)
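The snippet above builds the toy heatmap and smooths it but never uses the matplotlib import. Below is a minimal, self-contained sketch of how the raw and smoothed heatmaps could be displayed side by side; the figure layout and colormap are assumptions, not part of the commit.

import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter

keypoints = [(0, 0), (70, 80), (90, 30)]
heatmap = np.zeros((100, 100))
for y, x in keypoints:
    heatmap[y, x] = 1.0  # unit confidence at each keypoint
heatmap_smooth = gaussian_filter(heatmap, sigma=1)

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(heatmap, cmap='hot')         # isolated single-pixel peaks
axes[0].set_title('raw keypoints')
axes[1].imshow(heatmap_smooth, cmap='hot')  # peaks spread into Gaussian blobs
axes[1].set_title('after gaussian_filter, sigma=1')
plt.tight_layout()
plt.show()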

+ 0 - 0
models/line_detect/test_train2.py → models/line_detect/111.py


+ 32 - 0
models/line_detect/dataset_LD.py

@@ -27,6 +27,37 @@ from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks
 
 from tools.presets import DetectionPresetTrain
 
+def line_boxes1(target):
+    boxs = []
+    lines = target.cpu().numpy() * 4
+
+    if len(lines) > 0 and not (lines[0] == 0).all():
+        for i, ((a, b)) in enumerate(lines):
+            if i > 0 and (lines[i] == lines[0]).all():
+                break
+            # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no guaranteed ordering
+            if a[-1]==0. and b[-1]==0.:
+                continue
+
+            if a[1] > b[1]:
+                ymax = a[1] + 10
+                ymin = b[1] - 10
+            else:
+                ymin = a[1] - 10
+                ymax = b[1] + 10
+            if a[0] > b[0]:
+                xmax = a[0] + 10
+                xmin = b[0] - 10
+            else:
+                xmin = a[0] - 10
+                xmax = b[0] + 10
+            boxs.append([ymin, xmin, ymax, xmax])
+
+    # if boxs == []:
+    #     print(target)
+
+    return torch.tensor(boxs)
+
 
 class WirePointDataset(BaseDataset):
     def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
@@ -120,6 +151,7 @@ class WirePointDataset(BaseDataset):
         # return wire_labels, target
         target["wires"] = wire_labels
         target["boxes"] = line_boxes(target)
+        # target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
         target["labels"]= torch.ones(len(target["boxes"]),dtype=torch.int64)
         # print(f'target["labels"]:{ target["labels"]}')
         # print(f'boxes:{target["boxes"].shape}')

+ 5 - 0
models/line_detect/line_predictor.py

@@ -19,6 +19,8 @@ from models.config.config_tool import read_yaml
 import numpy as np
 import torch.nn.functional as F
 
+from scipy.ndimage import gaussian_filter
+
 FEATURE_DIM = 8
 
 def non_maximum_suppression(a):
@@ -388,6 +390,9 @@ class LineRCNNPredictor(nn.Module):
             )
             print(f'feat  shape:{feat.shape}')
 
+            # lmap = gaussian_filter(lmap, sigma=1)
+            # lmap = torch.from_numpy(gaussian_filter(lmap.cpu().numpy(), sigma=1)).to('cuda:0')
+
             line = torch.cat([xyu[:, None], xyv[:, None]], 1)
             # print(f'line:{line.shape}')
             n_channel, row, col = lmap.shape
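The commented-out lines above experiment with smoothing lmap via scipy's gaussian_filter, which only accepts numpy arrays, so a GPU tensor has to round-trip through the CPU. A minimal, device-agnostic sketch of that step, assuming lmap is a torch.Tensor of shape [C, H, W]; the helper name and the per-axis sigma are assumptions, not the committed code.

import torch
from scipy.ndimage import gaussian_filter

def smooth_lmap(lmap: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    # gaussian_filter works on numpy arrays only, so detach and move to CPU first
    smoothed = gaussian_filter(lmap.detach().cpu().numpy(), sigma=(0, sigma, sigma))  # skip the channel axis
    return torch.from_numpy(smoothed).to(device=lmap.device, dtype=lmap.dtype)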

+ 394 - 85
models/line_detect/predict2.py

@@ -1,6 +1,162 @@
+# import time
+#
+# from models.line_detect.postprocess import show_predict
+# import os
+#
+# import torch
+# from PIL import Image
+# import matplotlib.pyplot as plt
+# import matplotlib as mpl
+# import numpy as np
+# from models.line_detect.line_net import linenet_resnet50_fpn
+# from torchvision import transforms
+# from rtree import index
+# # from models.wirenet.postprocess import postprocess
+# from models.wirenet.postprocess import postprocess
+#
+# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+#
+#
+# def load_best_model(model, save_path, device):
+#     if os.path.exists(save_path):
+#         checkpoint = torch.load(save_path, map_location=device)
+#         model.load_state_dict(checkpoint['model_state_dict'])
+#         # if optimizer is not None:
+#         #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+#         epoch = checkpoint['epoch']
+#         loss = checkpoint['loss']
+#         print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+#     else:
+#         print(f"No saved model found at {save_path}")
+#     return model
+#
+#
+# def box_line_optimized(pred):
+#     # Build an R-tree index
+#     idx = index.Index()
+#
+#     # Insert every line segment into the R-tree
+#     lines = pred[-1]['wires']['lines']  # shape [1, 2500, 2, 2]
+#     scores = pred[-1]['wires']['score'][0]  # assumed shape [2500]
+#
+#     # Extract and process all line segments
+#     for idx_line in range(lines.shape[1]):  # iterate over the 2500 segments
+#         line_tensor = lines[0, idx_line].cpu().numpy() / 128 * 512  # convert to a numpy array and rescale
+#         x_min = float(min(line_tensor[0][0], line_tensor[1][0]))
+#         y_min = float(min(line_tensor[0][1], line_tensor[1][1]))
+#         x_max = float(max(line_tensor[0][0], line_tensor[1][0]))
+#         y_max = float(max(line_tensor[0][1], line_tensor[1][1]))
+#         idx.insert(idx_line, (x_min, y_min, x_max, y_max))
+#
+#     for idx_box, box_ in enumerate(pred[0:-1]):
+#         box = box_['boxes'].cpu().numpy()  # make sure the tensor is converted to a numpy array
+#         line_ = []
+#         score_ = []
+#
+#         for i in box:
+#             score_max = 0.0
+#             tmp = [[0.0, 0.0], [0.0, 0.0]]
+#
+#             # Get all segments that may intersect the current box
+#             possible_matches = list(idx.intersection((i[0], i[1], i[2], i[3])))
+#
+#             for j in possible_matches:
+#                 line_j = lines[0, j].cpu().numpy() / 128 * 512
+#                 if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and  # note that x and y are swapped here
+#                         line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
+#                         line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
+#                         line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
+#
+#                     if scores[j] > score_max:
+#                         tmp = line_j
+#                         score_max = scores[j]
+#
+#             line_.append(tmp)
+#             score_.append(score_max)
+#
+#         processed_list = torch.tensor(line_)
+#         pred[idx_box]['line'] = processed_list
+#
+#         processed_s_list = torch.tensor(score_)
+#         pred[idx_box]['line_score'] = processed_s_list
+#
+#     return pred
+#
+# # def box_line_(pred):
+# #     for idx, box_ in enumerate(pred[0:-1]):
+# #         box = box_['boxes']  # a tensor
+# #         line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
+# #         score = pred[-1]['wires']['score'][idx]
+# #         line_ = []
+# #         score_ = []
+# #
+# #         for i in box:
+# #             score_max = 0.0
+# #             tmp = [[0.0, 0.0], [0.0, 0.0]]
+# #
+# #             for j in range(len(line)):
+# #                 if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
+# #                         line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
+# #                         line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
+# #                         line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+# #
+# #                     if score[j] > score_max:
+# #                         tmp = line[j]
+# #                         score_max = score[j]
+# #             line_.append(tmp)
+# #             score_.append(score_max)
+# #         processed_list = torch.tensor(line_)
+# #         pred[idx]['line'] = processed_list
+# #
+# #         processed_s_list = torch.tensor(score_)
+# #         pred[idx]['line_score'] = processed_s_list
+# #     return pred
+#
+#
+# def predict(pt_path, model, img):
+#     model = load_best_model(model, pt_path, device)
+#
+#     model.eval()
+#
+#     if isinstance(img, str):
+#         img = Image.open(img).convert("RGB")
+#
+#     transform = transforms.ToTensor()
+#     img_tensor = transform(img)
+#
+#     with torch.no_grad():
+#         t_start = time.time()
+#         predictions = model([img_tensor.to(device)])
+#         t_end=time.time()
+#         print(f'predict used:{t_end-t_start}')
+#         # print(f'predictions:{predictions}')
+#         boxes=predictions[0]['boxes'].shape
+#         lines=predictions[-1]['wires']['lines'].shape
+#         lines_scores=predictions[-1]['wires']['score'].shape
+#         print(f'predictions boxes:{boxes},lines:{lines},lines_scores:{lines_scores}')
+#     t_start=time.time()
+#     pred = box_line_optimized(predictions)
+#     t_end=time.time()
+#     print(f'matched boxes and lines used:{t_end - t_start}')
+#     # print(f'pred:{pred[0]}')
+#     show_predict(img_tensor, pred, t_start)
+#
+#
+# if __name__ == '__main__':
+#     t_start = time.time()
+#     print(f'start to predict:{t_start}')
+#     model = linenet_resnet50_fpn().to(device)
+#     pt_path = r"F:\BaiduNetdiskDownload\resnet50_best_e8.pth"
+#     img_path = r"I:\datasets\wirenet_1000\images\val\00035148_0.png"
+#     predict(pt_path, model, img_path)
+#     t_end = time.time()
+#     # print(f'predict used:{t_end - t_start}')
+
+
 import time
 
-from models.line_detect.postprocess import show_predict
+import skimage
+
 import os
 
 import torch
@@ -10,9 +166,10 @@ import matplotlib as mpl
 import numpy as np
 from models.line_detect.line_net import linenet_resnet50_fpn
 from torchvision import transforms
-from rtree import index
+
 # from models.wirenet.postprocess import postprocess
 from models.wirenet.postprocess import postprocess
+from rtree import index
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
@@ -23,95 +180,233 @@ def load_best_model(model, save_path, device):
         model.load_state_dict(checkpoint['model_state_dict'])
         # if optimizer is not None:
         #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-        epoch = checkpoint['epoch']
-        loss = checkpoint['loss']
-        print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+        # epoch = checkpoint['epoch']
+        # loss = checkpoint['loss']
+        # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
     else:
         print(f"No saved model found at {save_path}")
     return model
 
 
-def box_line_optimized(pred):
-    # Build an R-tree index
-    idx = index.Index()
+def box_line_(imgs, pred, length=False):    # by default, match by confidence score
+    im = imgs.permute(1, 2, 0).cpu().numpy()
+    line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
+
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
+    for idx, box_ in enumerate(pred[0:-1]):
+        box = box_['boxes']  # a tensor
+        line_ = []
+        score_ = []
+        for i in box:
+            score_max = 0.0
+            tmp = [[0.0, 0.0], [0.0, 0.0]]
+
+            for j in range(len(line)):
+                if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
+                        line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
+                        line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
+                        line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+
+                    if score[j] > score_max:
+                        tmp = line[j]
+                        score_max = score[j]
+            line_.append(tmp)
+            score_.append(score_max)
+        processed_list = torch.tensor(line_)
+        pred[idx]['line'] = processed_list
+
+        processed_s_list = torch.tensor(score_)
+        pred[idx]['line_score'] = processed_s_list
+    return pred
 
-    # Insert every line segment into the R-tree
-    lines = pred[-1]['wires']['lines']  # shape [1, 2500, 2, 2]
-    scores = pred[-1]['wires']['score'][0]  # assumed shape [2500]
+# When a box contains no line segment, return the segment formed by the two points inside the box that are farthest apart
+def box_line1(imgs, pred, length=False):    # by default, match by confidence score
+    im = imgs.permute(1, 2, 0).cpu().numpy()
+    line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
 
-    # Extract and process all line segments
-    for idx_line in range(lines.shape[1]):  # iterate over the 2500 segments
-        line_tensor = lines[0, idx_line].cpu().numpy() / 128 * 512  # convert to a numpy array and rescale
-        x_min = float(min(line_tensor[0][0], line_tensor[1][0]))
-        y_min = float(min(line_tensor[0][1], line_tensor[1][1]))
-        x_max = float(max(line_tensor[0][0], line_tensor[1][0]))
-        y_max = float(max(line_tensor[0][1], line_tensor[1][1]))
-        idx.insert(idx_line, (x_min, y_min, x_max, y_max))
+    points=pred[-1]['wires']['juncs'].cpu().numpy()[0]/ 128 * 512
 
-    for idx_box, box_ in enumerate(pred[0:-1]):
-        box = box_['boxes'].cpu().numpy()  # make sure the tensor is converted to a numpy array
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
+    for idx, box_ in enumerate(pred[0:-1]):
+        box = box_['boxes']  # a tensor
         line_ = []
         score_ = []
-
         for i in box:
             score_max = 0.0
             tmp = [[0.0, 0.0], [0.0, 0.0]]
 
-            # Get all segments that may intersect the current box
-            possible_matches = list(idx.intersection((i[0], i[1], i[2], i[3])))
+            for j in range(len(line)):
+                if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
+                        line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
+                        line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
+                        line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+
+                    if score[j] > score_max:
+                        tmp = line[j]
+                        score_max = score[j]
+            # If no segment fell inside the box, build the longest segment from the junction points instead
+            if score_max == 0.0:  # no line segment inside this box
+                box_points = [
+                    [x, y] for x, y in points
+                    if i[0] <= y <= i[2] and i[1] <= x <= i[3]
+                ]
+
+                if len(box_points) >= 2:  # at least two points are needed to form a segment
+                    max_distance = 0.0
+                    longest_segment = [[0.0, 0.0], [0.0, 0.0]]
 
-            for j in possible_matches:
-                line_j = lines[0, j].cpu().numpy() / 128 * 512
-                if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and  # note that x and y are swapped here
-                        line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
-                        line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
-                        line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
+                    # Find the longest segment formed by the points inside the box
+                    for p1 in box_points:
+                        for p2 in box_points:
+                            if p1 != p2:
+                                distance = np.linalg.norm(np.array(p1) - np.array(p2))
+                                if distance > max_distance:
+                                    max_distance = distance
+                                    longest_segment = [p1, p2]
 
-                    if scores[j] > score_max:
-                        tmp = line_j
-                        score_max = scores[j]
+                    tmp = longest_segment
+                    score_max = 0.0  # keep the default score of 0.0
 
             line_.append(tmp)
             score_.append(score_max)
-
         processed_list = torch.tensor(line_)
-        pred[idx_box]['line'] = processed_list
+        pred[idx]['line'] = processed_list
 
         processed_s_list = torch.tensor(score_)
-        pred[idx_box]['line_score'] = processed_s_list
-
+        pred[idx]['line_score'] = processed_s_list
     return pred
 
-# def box_line_(pred):
-#     for idx, box_ in enumerate(pred[0:-1]):
-#         box = box_['boxes']  # a tensor
-#         line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
-#         score = pred[-1]['wires']['score'][idx]
-#         line_ = []
-#         score_ = []
-#
-#         for i in box:
-#             score_max = 0.0
-#             tmp = [[0.0, 0.0], [0.0, 0.0]]
-#
-#             for j in range(len(line)):
-#                 if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
-#                         line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
-#                         line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
-#                         line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
-#
-#                     if score[j] > score_max:
-#                         tmp = line[j]
-#                         score_max = score[j]
-#             line_.append(tmp)
-#             score_.append(score_max)
-#         processed_list = torch.tensor(line_)
-#         pred[idx]['line'] = processed_list
-#
-#         processed_s_list = torch.tensor(score_)
-#         pred[idx]['line_score'] = processed_s_list
-#     return pred
+def show_box(imgs, pred, t_start):
+    col = [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+    # print(len(col))
+    im = imgs.permute(1, 2, 0)
+    boxes = pred[0]['boxes'].cpu().numpy()
+    box_scores = pred[0]['scores'].cpu().numpy()
 
+    # Visualize the predictions
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(im))
+
+    for idx, box in enumerate(boxes):
+        # if box_scores[idx] < 0.7:
+        #     continue
+        x0, y0, x1, y1 = box
+        ax.add_patch(
+            plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+
+    t_end = time.time()
+    print(f'show_box used:{t_end - t_start}')
+    plt.show()
+
+
+def show_predict(imgs, pred, t_start):
+    col = [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+    im = imgs.permute(1, 2, 0)  # 处理为 [512, 512, 3]
+    boxes = pred[0]['boxes'].cpu().numpy()
+    box_scores = pred[0]['scores'].cpu().numpy()
+    lines = pred[0]['line'].cpu().numpy()
+    line_scores = pred[0]['line_score'].cpu().numpy()
+
+    # Visualize the predictions
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(im))
+    idx = 0
+
+    tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
+
+    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
+        x0, y0, x1, y1 = box
+        # skip boxes with no matched line segment
+        if np.array_equal(line, tmp):
+            continue
+        a, b = line
+        if box_score >= 0 or line_score >= 0:
+            ax.add_patch(
+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+            ax.scatter(a[1], a[0], c='#871F78', s=10)
+            ax.scatter(b[1], b[0], c='#871F78', s=10)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
+            idx = idx + 1
+    t_end = time.time()
+    print(f'predict used:{t_end - t_start}')
+
+    plt.show()
+
+def show_line(imgs, pred, t_start):
+    im = imgs.permute(1, 2, 0)
+    line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    # print(pred[-1]['wires']['score'])
+    line_score = pred[-1]['wires']['score'].cpu().numpy()[0]
+
+    t1 = time.time()
+    print(f't1:{t1 - t_start}')
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    line, line_score = postprocess(line, line_score, diag * 0.01, 0, False)
+    print(f'lines num:{len(line)}')
+    t2 = time.time()
+    print(f't2:{t2 - t1}')
+
+    # Visualize the predictions
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(im))
+
+    for idx, (a, b) in enumerate(line):
+        # if line_score[idx] < 0.7:
+        #     continue
+        ax.scatter(a[1], a[0], c='#871F78', s=2)
+        ax.scatter(b[1], b[0], c='#871F78', s=2)
+        ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+
+    t_end = time.time()
+    print(f'show_line used:{t_end - t_start}')
+
+    plt.show()
 
 def predict(pt_path, model, img):
     model = load_best_model(model, pt_path, device)
@@ -122,32 +417,46 @@ def predict(pt_path, model, img):
         img = Image.open(img).convert("RGB")
 
     transform = transforms.ToTensor()
-    img_tensor = transform(img)
+    img_tensor = transform(img)  # [3, 512, 512]
+
+    # Resize the image to 512x512
+    t_start = time.time()
+    im = img_tensor.permute(1, 2, 0)  # [512, 512, 3]
+    im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
+    img_ = torch.tensor(im_resized).permute(2, 0, 1)
+    t_end = time.time()
+    print(f'switch img used:{t_end - t_start}')
 
     with torch.no_grad():
-        t_start = time.time()
-        predictions = model([img_tensor.to(device)])
-        t_end=time.time()
-        print(f'predict used:{t_end-t_start}')
-        # print(f'predictions:{predictions}')
-        boxes=predictions[0]['boxes'].shape
-        lines=predictions[-1]['wires']['lines'].shape
-        lines_scores=predictions[-1]['wires']['score'].shape
-        print(f'predictions boxes:{boxes},lines:{lines},lines_scores:{lines_scores}')
-    t_start=time.time()
-    pred = box_line_optimized(predictions)
-    t_end=time.time()
-    print(f'matched boxes and lines used:{t_end - t_start}')
-    # print(f'pred:{pred[0]}')
-    show_predict(img_tensor, pred, t_start)
+        predictions = model([img_.to(device)])
+        print(predictions)
+    t_end1 = time.time()
+    print(f'model test used:{t_end1 - t_end}')
+
+    # show_line_optimized(img_, predictions, t_start)   # draw lines only
+    show_line(img_, predictions, t_end1)
+    t_end2 = time.time()
+    show_box(img_, predictions, t_end2)   # draw boxes only
+    # show_box_or_line(img_, predictions, show_line=True, show_box=True)   # flags choose what to draw
+    # show_box_and_line(img_, predictions, show_line=True, show_box=True)  # draw both in a 1x2 figure
+
+    t_start = time.time()
+    pred = box_line_(img_, predictions)
+    t_end = time.time()
+    print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
+
+    show_predict(img_, pred, t_start)
 
 
 if __name__ == '__main__':
     t_start = time.time()
     print(f'start to predict:{t_start}')
     model = linenet_resnet50_fpn().to(device)
-    pt_path = r"F:\BaiduNetdiskDownload\resnet50_best_e8.pth"
-    img_path = r"I:\datasets\wirenet_1000\images\val\00035148_0.png"
+    # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练24轮结果.pth"
+    # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
+    pt_path = r"C:\Users\m2337\Downloads\best_e150.pth"
+    # pt_path = r"C:\Users\m2337\Downloads\best_e20.pth"
+    img_path = r"C:\Users\m2337\Desktop\p\新建文件夹\2025-03-25-16-10-00_SaveLeftImage.png"
     predict(pt_path, model, img_path)
     t_end = time.time()
-    # print(f'predict used:{t_end - t_start}')
+    print(f'predict used:{t_end - t_start}')

+ 1 - 1
models/line_detect/roi_heads.py

@@ -253,7 +253,7 @@ def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
         result["preds"]["junts"] = torch.cat(
             [jcs[i][1] for i in range(n_batch)]
         )
-    print(f'predic result:{result}')
+    # print(f'predic result:{result}')
     return result
 
 

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: logs/
-  datadir: I:/datasets/0322_suanzaisheng
+  datadir: D:\all\1Desktop\20250320data\0322_
 #  datadir: I:\datasets\wirenet_1000
   resume_from:
   num_workers: 8