
Modify model.predict

xue50, 7 months ago
parent
commit 14a3221d39
6 files changed, 620 insertions(+), 24 deletions(-)
  1. .gitignore  (+1, -0)
  2. models/line_detect/111.py  (+11, -11)
  3. models/line_detect/aaa.py  (+474, -0)
  4. models/line_detect/line_net.py  (+22, -1)
  5. models/line_detect/predict.py  (+102, -2)
  6. models/line_detect/predict2.py  (+10, -10)

+ 1 - 0
.gitignore

@@ -1,6 +1,7 @@
 .idea
 *.pt
 *.pth
+*.png
 
 *.log
 *.onnx

+ 11 - 11
models/line_detect/111.py

@@ -48,6 +48,7 @@ class Trainer(BaseTrainer):
                  **kwargs):
         super().__init__(model, dataset, device, **kwargs)
         self.freeze_config = freeze_config or {}  # default: empty freeze config
+
     def move_to_device(self, data, device):
         if isinstance(data, (list, tuple)):
             return type(data)(self.move_to_device(item, device) for item in data)
@@ -57,11 +58,12 @@ class Trainer(BaseTrainer):
             return data.to(device)
         else:
             return data  # leave non-tensor data unchanged
+
     def freeze_params(self, model):
         """Freeze model parameters according to the config."""
         default_config = {
             'backbone': True,  # freeze backbone
-            'rpn': False,      # do not freeze rpn
+            'rpn': False,  # do not freeze rpn
             'roi_heads': {
                 'box_head': False,
                 'box_predictor': False,
@@ -226,11 +228,11 @@ class Trainer(BaseTrainer):
 import torch
 
 from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn
+
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if __name__ == '__main__':
-
     # model = LineNet('line_net.yaml')
-    model=linenet_resnet50_fpn().to(device)
+    model = linenet_resnet50_fpn().to(device)
     # model=linenet_resnet18_fpn()
     # trainer = Trainer()
     # trainer.train_cfg(model,cfg='./train.yaml')
@@ -238,12 +240,10 @@ if __name__ == '__main__':
     # trainer = Trainer()
     # trainer.train_cfg(model=model, cfg='train.yaml')
     #
-    pt_path = r"E:\projects\tmp\MultiVisionModels\models\line_detect\train_results\20250424_162124\weights\best.pth"
-    img_path = r"I:\datasets\4_23jiagonggongjian\images\val\2025-04-23-08-52-00_SaveRightImage.png"
-
-    model.predict(pt_path, model, img_path, type=1, threshold=0, save_path=None, show=True)
-
-
-
-
+    # pt_path = r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth"
+    # img_path = r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png"
+    # model.predict(pt_path, model, img_path, type=1, threshold=0, save_path=None, show=True)
 
+    model = model.load_best_model(model, r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth")
+    img_path = r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png"
+    model.predict1(model, img_path, type=1, threshold=0, save_path=None, show=True)
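Note on the new flow shown above: the script now restores the weights once through load_best_model and then runs inference with predict1 (both added to LineNet in line_net.py below), instead of letting predict reload the checkpoint on every call. For reference, here is a hedged sketch of a checkpoint writer that produces a file load_best_model can read, followed by the intended call pattern; the save routine, optimizer handling, and paths are illustrative assumptions, not part of this commit.

import torch

def save_checkpoint(model, optimizer, epoch, loss, save_path):
    # key names chosen to match what load_best_model reads ('model_state_dict');
    # the remaining entries mirror the commented-out fields and are optional
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss,
    }, save_path)

# intended call pattern (paths are placeholders):
# model = linenet_resnet50_fpn().to(device)
# model = model.load_best_model(model, "weights/best.pth")            # load weights once
# model.predict1(model, "sample.png", type=1)                         # inference only, no reload
# model.predict("weights/best.pth", model, "sample.png", type=1)      # older path: load + infer in one call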

+ 474 - 0
models/line_detect/aaa.py

@@ -0,0 +1,474 @@
+# from fastapi import FastAPI, File, UploadFile, HTTPException
+# from fastapi.responses import FileResponse
+# from fastapi.staticfiles import StaticFiles
+import os
+import torch
+import numpy as np
+from PIL import Image
+import skimage.io
+import skimage.color
+import skimage.transform  # needed for skimage.transform.resize below
+from torchvision import transforms
+import shutil
+import matplotlib.pyplot as plt
+from models.line_detect.line_net import linenet_resnet50_fpn
+from models.wirenet.postprocess import postprocess
+from rtree import index
+import time
+import multiprocessing as mp
+
+# from code.ubuntu2204.VisionWeldRobotMK.multivisionmodels.models.line_detect.boxline import show_box
+
+# set the multiprocessing start method to 'spawn'
+mp.set_start_method('spawn', force=True)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+def load_model(model_path):
+    """Load the model weights and return the model instance."""
+    model = linenet_resnet50_fpn().to(device)
+    if os.path.exists(model_path):
+        checkpoint = torch.load(model_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        print(f"Loaded model from {model_path}")
+    else:
+        raise FileNotFoundError(f"No saved model found at {model_path}")
+    model.eval()
+    return model
+
+def preprocess_image(image_path):
+    """Preprocess the uploaded image."""
+    img = Image.open(image_path).convert("RGB")
+    transform = transforms.ToTensor()
+    img_tensor = transform(img)
+    resized_img = skimage.transform.resize(
+        img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
+    )
+    return torch.tensor(resized_img).permute(2, 0, 1),img
+
+def save_plot(output_path: str):
+    """Save the figure and close the plot."""
+    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
+    print(f"Saved plot to {output_path}")
+    plt.close()
+
+def get_colors():
+    """Return a list of predefined colors."""
+    return [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+
+def process_box(box, lines, scores):
+    """For each bounding box, find the best-matching line segment."""
+    valid_lines = []  # store valid line segments
+    valid_scores = []  # store the corresponding scores
+    # print(f'score:{len(scores)}')
+    for i in box:
+        best_line = None
+        max_length = 0.0
+        # iterate over all line segments
+        for j in range(lines.shape[1]):
+            # line_j = lines[0, j].cpu().numpy() / 128 * 512
+            line_j = lines[0, j].cpu().numpy()
+            # check whether the segment lies entirely inside the box
+            if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
+                    line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
+                    line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
+                    line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
+                # compute the segment length
+                length = np.linalg.norm(line_j[0] - line_j[1])
+                # length = scores[j].cpu().numpy()
+                # print(length)
+                if length > max_length:
+                    best_line = line_j
+                    max_length = length
+        # if a valid segment was found, add it to the results
+        if best_line is not None:
+            valid_lines.append(best_line)
+            valid_scores.append(max_length)  # use the segment length as the score
+
+        else:
+            valid_lines.append([[0.0,0.0],[0.0,0.0]])
+            valid_scores.append(0.0)  # use the segment confidence as the score
+        # print(f'valid_lines:{valid_lines}')
+        # print(f'valid_scores:{valid_scores}')
+    return valid_lines, valid_scores
+
+
+# def box_line_optimized_parallel(imgs, pred):  # default: confidence
+#     im = imgs.permute(1, 2, 0).cpu().numpy()
+#     line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+#     line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
+#
+#     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+#     line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
+#     for idx, box_ in enumerate(pred[0:-1]):
+#         box = box_['boxes']
+#
+#         line_ = []
+#         score_ = []
+#
+#         for i in box:
+#             score_max = 0.0
+#             tmp = [[0.0, 0.0], [0.0, 0.0]]
+#
+#             for j in range(len(line)):
+#                 if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
+#                         line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
+#                         line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
+#                         line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+#
+#                     if score[j] > score_max:
+#                         tmp = line[j]
+#                         score_max = score[j]
+#             line_.append(tmp)
+#             score_.append(score_max)
+#         processed_list = torch.tensor(line_)
+#         pred[idx]['line'] = processed_list
+#
+#         processed_s_list = torch.tensor(score_)
+#         pred[idx]['line_score'] = processed_s_list
+#     del pred[-1]
+#     return pred
+
+def box_line_optimized_parallel(imgs, pred, length=False):    # default: confidence
+    im = imgs.permute(1, 2, 0).cpu().numpy()
+    line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
+    # print(f'line_data:{line_data}')
+
+    points=pred[-1]['wires']['juncs'].cpu().numpy()[0]/ 128 * 512
+
+    is_all_zeros = np.all(line_data == 0.0)
+    if is_all_zeros:
+        for idx, box_ in enumerate(pred[0:-1]):
+            score_max = 0.0
+            tmp = [[0.0, 0.0], [0.0, 0.0]]
+            processed_list = torch.tensor(tmp)
+            pred[idx]['line'] = processed_list
+
+            processed_s_list = torch.tensor(score_max)
+            pred[idx]['line_score'] = processed_s_list
+    else:
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
+        for idx, box_ in enumerate(pred[0:-2]):
+            box = box_['boxes']  # a tensor
+            line_ = []
+            score_ = []
+            for i in box:
+                score_max = 0.0
+                tmp = [[0.0, 0.0], [0.0, 0.0]]
+
+                for j in range(len(line)):
+                    if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
+                            line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
+                            line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
+                            line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+
+                        if score[j] > score_max:
+                            tmp = line[j]
+                            score_max = score[j]
+                # if the box contains no line segment, build one from the junction points
+                if score_max == 0.0:  # no line segment inside the box
+                    box_points = [
+                        [x, y] for x, y in points
+                        if i[0] <= y <= i[2] and i[1] <= x <= i[3]
+                    ]
+
+                    if len(box_points) >= 2:  # need at least two points to form a segment
+                        max_distance = 0.0
+                        longest_segment = [[0.0, 0.0], [0.0, 0.0]]
+
+                        # find the longest segment formed by the points inside the box
+                        for p1 in box_points:
+                            for p2 in box_points:
+                                if p1 != p2:
+                                    distance = np.linalg.norm(np.array(p1) - np.array(p2))
+                                    if distance > max_distance:
+                                        max_distance = distance
+                                        longest_segment = [p1, p2]
+
+                        tmp = longest_segment
+                        score_max = 0.0  # default score is 0.0
+
+                line_.append(tmp)
+                score_.append(score_max)
+            processed_list = torch.tensor(line_)
+            pred[idx]['line'] = processed_list
+
+            processed_s_list = torch.tensor(score_)
+            pred[idx]['line_score'] = processed_s_list
+    return pred
+
+def show_predict1(imgs, pred, t_start):
+    col = [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+    im = imgs.permute(1, 2, 0)  # arrange as [512, 512, 3]
+    boxes = pred[0]['boxes'].cpu().numpy()
+    box_scores = pred[0]['scores'].cpu().numpy()
+    lines = pred[0]['line'].cpu().numpy()
+    line_scores = pred[0]['line_score'].cpu().numpy()
+
+    # visualize the prediction results
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(im))
+    idx = 0
+
+    tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
+
+    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
+        x0, y0, x1, y1 = box
+        # skip boxes with no matched line
+        if np.array_equal(line, tmp):
+            continue
+        a, b = line
+        if box_score >= 0 or line_score >= 0:
+            ax.add_patch(
+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+            ax.scatter(a[1], a[0], c='#871F78', s=10)
+            ax.scatter(b[1], b[0], c='#871F78', s=10)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
+            idx = idx + 1
+    t_end = time.time()
+    print(f'predict used:{t_end - t_start}')
+    plt.savefig("temp_result.png")
+    plt.show()
+    # output_path = "temp_result.png"
+    # save_plot(output_path)
+    # return output_path
+
+
+def predict(image_path):
+
+    start_time = time.time()
+
+    model_path = r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth"
+    model = load_model(model_path)
+
+    img_tensor,_ = preprocess_image(image_path)
+    print(f'img shape:{img_tensor.shape}')
+
+    # model inference
+    with torch.no_grad():
+      predictions = model([img_tensor.to(device)])
+    # print(f'predictions[0]:{predictions[0]}')
+    # print(f'predictions[1]:{predictions[1]["wires"]["lines"]}')
+    # lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 512 * np.array([2112, 1328])
+
+    start_time1 = time.time()
+    show_line(img_tensor.permute(1, 2, 0).cpu().numpy(), predictions, start_time1)
+    show_box(img_tensor.permute(1, 2, 0).cpu().numpy(), predictions)
+
+    H = predictions[-1]['wires']
+    lines = H["lines"][0].cpu().numpy() / 128 * np.array([2112, 1328])
+    scores = H["score"][0].cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # postprocess lines to remove overlapped lines
+    diag = (512 ** 2 + 512 ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+
+    # t_start = time.time()
+    # filtered_pred = box_line_optimized_parallel(img_tensor,predictions)
+    # # print(f'after matching: {filtered_pred}')
+    # print(f'after matching len: {filtered_pred[0]}')
+    # t_end = time.time()
+    # print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')
+    # show_predict(img_tensor.permute(1, 2, 0).cpu().numpy(), filtered_pred, start_time1)
+
+
+    # # merge the images
+    # combined_image_path = "combined_result.png"
+    # combine_images(
+    #     [output_path_boxandline, output_path_box, output_path_line],
+    #     titles=["Box and Line", "Box", "Line"],
+    #     output_path=combined_image_path
+    # )
+
+    # end_time = time.time()
+    # print(f'Total time: {end_time - start_time:.2f} seconds')
+
+    # lines = filtered_pred[0]['line'].cpu().numpy() / 512 * np.array([2112, 1328])
+    print(f'line segments len:{len(nlines)}')
+    # print(f"Initial lines shape: {lines.shape}")
+    # print(f"Initial lines data type: {lines.dtype}")
+
+    formatted_lines = []
+    for line in nlines:
+        if (line == [[0.0, 0.0], [0.0, 0.0]]).all():
+            continue
+            #line=[[500.0, 500.0], [650.0, 650.0]]
+        start_point = np.array([line[0][0], line[0][1]])
+        end_point = np.array([line[1][0], line[1][1]])
+        formatted_lines.append([start_point, end_point])
+
+    formatted_lines = np.array(formatted_lines)
+    print(f"Final formatted_lines shape: {formatted_lines.shape}")
+    print(f"Sample formatted line: {formatted_lines[0] if len(formatted_lines) > 0 else 'No lines'}")
+
+     # ensure the returned value is a 3-D array: [lines_array]
+    result = [formatted_lines]
+    print(f"Final result type: {type(result)}")
+    print(f"Final result[0] shape: {result[0].shape}")
+
+    return result
+
+
+
+def show_line(im, predictions, start_time1):
+    """Draw the line segments and save the result."""
+    lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
+
+    is_all_zeros = np.all(lines == 0.0)
+    if is_all_zeros:
+        fig, ax = plt.subplots(figsize=(10, 10))
+        t_end = time.time()
+        plt.savefig("temp_line.png")
+    else:
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+        fig, ax = plt.subplots(figsize=(10, 10))
+        ax.imshow(im)
+        for (a, b), s in zip(nlines, nscores):
+            if s < 0:
+                continue
+            ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+        t_end = time.time()
+        plt.savefig("temp_line.png")
+    print(f'show line time:{t_end-start_time1}')
+
+
+    # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    # nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+    # fig, ax = plt.subplots(figsize=(10, 10))
+    # ax.imshow(im)
+    # for (a, b), s in zip(nlines, nscores):
+    #     if s < 0:
+    #         continue
+    #     ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
+    #     ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+    # t_end = time.time()
+    # plt.savefig("temp_line.png")
+
+def show_box(im, predictions):
+    """Draw the bounding boxes and save the result."""
+    boxes = predictions[0]['boxes'].cpu().numpy()
+    box_scores = predictions[0]['scores'].cpu().numpy()
+    colors = get_colors()
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for idx, (box, score) in enumerate(zip(boxes, box_scores)):
+        if score < 0:
+            continue
+        x0, y0, x1, y1 = box
+        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=colors[idx % len(colors)], linewidth=1))
+    t_end = time.time()
+    plt.savefig("temp_box.png")
+
+def show_predict(im, filtered_pred, t_start):
+    """Draw the matched bounding boxes and line segments and save the result."""
+    colors = get_colors()
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    pred = filtered_pred[0]
+
+    boxes = pred['boxes'].cpu().numpy()
+    box_scores = pred['scores'].cpu().numpy()
+    lines = pred['line'].cpu().numpy()
+    line_scores = pred['line_score'].cpu().numpy()
+    print("Boxes:", pred['boxes'])
+    print("Lines:", pred['line'])
+    print("Line scores:", pred['line_score'])
+
+    is_all_zeros = np.all(lines == 0.0)
+    if not is_all_zeros:
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+        for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
+            if box_score < 0.0 or line_score < 0.0:
+                continue
+
+            # if the line is empty (no valid segment was found), skip drawing
+            if line is None or len(line) == 0:
+                continue
+
+            x0, y0, x1, y1 = box
+            a, b = line
+            color = colors[box_idx % len(colors)]  # assign each bounding box a unique color
+            ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
+            ax.scatter(a[1], a[0], c=color, s=10)
+            ax.scatter(b[1], b[0], c=color, s=10)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
+
+        # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        # nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+        # for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
+        #     if box_score < 0.0 or line_score < 0.0:
+        #         continue
+        #
+        #     # if the line is empty (no valid segment was found), skip drawing
+        #     if line is None or len(line) == 0:
+        #         continue
+        #
+        #     x0, y0, x1, y1 = box
+        #     a, b = line
+        #     color = colors[(idx + box_idx) % len(colors)]  # assign each bounding box a unique color
+        #     ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
+        #     ax.scatter(a[1], a[0], c=color, s=10)
+        #     ax.scatter(b[1], b[0], c=color, s=10)
+        #     ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
+    t_end = time.time()
+    # plt.show()
+    print(f'show_predict used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result.png"
+    save_plot(output_path)
+    return output_path
+
+if __name__ == "__main__":
+    lines = predict(r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png")
+    print(f'lines:{lines}')
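For reference, the matching rule in box_line_optimized_parallel keeps, for each detected box, the highest-scoring line whose two endpoints (stored in (y, x) order) lie entirely inside the box, and falls back to the longest segment between the predicted junction points when no line qualifies. The minimal sketch below restates that containment-plus-fallback idea on toy NumPy inputs; the function name and sample arrays are hypothetical and it is not a drop-in replacement for the code above.

import numpy as np

def match_lines_to_boxes(boxes, lines, scores, juncs):
    """Toy sketch: boxes are (x0, y0, x1, y1); line endpoints and juncs are (y, x)."""
    matched, matched_scores = [], []
    for x0, y0, x1, y1 in boxes:
        best = np.zeros((2, 2))
        best_score = 0.0
        for (a, b), s in zip(lines, scores):
            inside = (y0 <= a[0] <= y1 and y0 <= b[0] <= y1 and
                      x0 <= a[1] <= x1 and x0 <= b[1] <= x1)
            if inside and s > best_score:
                best, best_score = np.array([a, b]), float(s)
        if best_score == 0.0:
            # fallback: longest segment between junction points that fall inside the box
            pts = [p for p in juncs if y0 <= p[0] <= y1 and x0 <= p[1] <= x1]
            best_len = 0.0
            for i in range(len(pts)):
                for j in range(i + 1, len(pts)):
                    d = float(np.linalg.norm(np.asarray(pts[i]) - np.asarray(pts[j])))
                    if d > best_len:
                        best_len, best = d, np.array([pts[i], pts[j]])
        matched.append(best)
        matched_scores.append(best_score)
    return np.array(matched), np.array(matched_scores)

# hypothetical toy inputs
boxes = np.array([[10.0, 10.0, 100.0, 100.0]])
lines = np.array([[[20.0, 20.0], [80.0, 80.0]]])  # one segment, endpoints in (y, x)
scores = np.array([0.9])
juncs = np.array([[15.0, 15.0], [90.0, 95.0]])
print(match_lines_to_boxes(boxes, lines, scores, juncs))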

+ 22 - 1
models/line_detect/line_net.py

@@ -1,3 +1,4 @@
+import os
 from typing import Any, Callable, List, Optional, Tuple, Union
 import torch
 from torch import nn
@@ -26,7 +27,7 @@ from ..base import backbone_factory
 from ..base.base_detection_net import BaseDetectionNet
 import torch.nn.functional as F
 
-from .predict import Predict
+from .predict import Predict1, Predict
 
 from ..config.config_tool import read_yaml
 
@@ -213,10 +214,30 @@ class LineNet(BaseDetectionNet):
         self.trainer = Trainer()
         self.trainer.train_cfg(model=self, cfg=cfg)
 
+    def load_best_model(self,model,  save_path, device='cuda'):
+        if os.path.exists(save_path):
+            checkpoint = torch.load(save_path, map_location=device)
+            model.load_state_dict(checkpoint['model_state_dict'])
+            # if optimizer is not None:
+            #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            # epoch = checkpoint['epoch']
+            # loss = checkpoint['loss']
+            # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+            print(f"Loaded model from {save_path}")
+        else:
+            print(f"No saved model found at {save_path}")
+        return model
+
+    # loads the weights and runs inference in one call
     def predict(self, pt_path, model, img_path, type=0, threshold=0.5, save_path=None, show=False):
         self.predict = Predict(pt_path, model, img_path, type, threshold, save_path, show)
         self.predict.run()
 
+    # does not load weights
+    def predict1(self, model, img_path, type=0, threshold=0.5, save_path=None, show=False):
+        self.predict = Predict1(model, img_path, type, threshold, save_path, show)
+        self.predict.run()
+
 
 class TwoMLPHead(nn.Module):
     """

+ 102 - 2
models/line_detect/predict.py

@@ -41,7 +41,7 @@ def box_line_(imgs, pred):  # default: confidence
     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
     line, score = postprocess(lines, scores, diag * 0.01, 0, False)
     # print(f'333:{len(lines)}')
-    for idx, box_ in enumerate(pred[0:-1]):
+    for idx, box_ in enumerate(pred[0:-2]):
         box = box_['boxes']  # a tensor
 
         line_ = []
@@ -330,7 +330,7 @@ class Predict:
         # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
         if im.shape != (512, 512, 3):
             im_resized = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
-        img_ = torch.tensor(im_resized).permute(2, 0, 1)  # [3, 512, 512]
+        img_ = torch.tensor(im).permute(2, 0, 1)  # [3, 512, 512]
         t_end = time.time()
         print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")
 
@@ -369,3 +369,103 @@ class Predict:
     def run(self):
         """Run the prediction pipeline."""
         self.predict()
+
+class Predict1:
+    def __init__(self, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
+        """
+        Initialize the predictor.
+
+        Parameters:
+            model: model instance (Predict1 does not load weights itself, so they should already be loaded).
+            img: input image (file path or PIL image object).
+            type: what to render (0: all plots - lines, boxes, and box-line matches; 1: line plot; 2: box plot; 3: box-line match plot).
+            threshold: score threshold used to filter predictions.
+            save_path: path to save the results (optional).
+            show_line / show_box: whether to draw the lines / the boxes.
+        """
+        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
+        self.model = model
+        self.img = self.load_image(img)
+        self.type = type
+        self.threshold = threshold
+        self.save_path = save_path
+        self.show_line = show_line
+        self.show_box = show_box
+
+    def load_best_model(self, model, save_path, device):
+        if os.path.exists(save_path):
+            checkpoint = torch.load(save_path, map_location=device)
+            model.load_state_dict(checkpoint['model_state_dict'])
+            # if optimizer is not None:
+            #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            # epoch = checkpoint['epoch']
+            # loss = checkpoint['loss']
+            # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+        else:
+            print(f"No saved model found at {save_path}")
+        return model
+
+    def load_image(self, img):
+        """Load the image."""
+        if isinstance(img, str):
+            img = Image.open(img).convert("RGB")
+        return img
+
+    def preprocess_image(self, img):
+        """Preprocess the image."""
+        transform = transforms.ToTensor()
+        img_tensor = transform(img)  # [3, H, W]
+
+        # resize to 512x512
+        t_start = time.time()
+        im = img_tensor.permute(1, 2, 0)  # [H, W, 3]
+        # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
+        if im.shape != (512, 512, 3):
+            im_resized = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
+        img_ = torch.tensor(im).permute(2, 0, 1)  # [3, 512, 512]
+        t_end = time.time()
+        print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")
+
+        return img_
+
+    def predict(self):
+        """Run prediction."""
+        # model = self.load_best_model(self.model, self.pt_path, device)
+        model = self.model
+
+        model.eval()
+
+        # preprocess the image
+        img_ = self.preprocess_image(self.img)
+
+        # model inference
+        with torch.no_grad():
+            predictions = model([img_.to(self.device)])
+            print("Model predictions completed.")
+
+
+        # display or save results according to the requested type
+        if self.type == 0:
+            # post-processing
+            t_start = time.time()
+            pred = box_line_(img_, predictions)  # match boxes with lines
+            t_end = time.time()
+            print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
+            show_all(img_, pred, self.threshold, save_path=self.save_path)
+        elif self.type == 1:
+            show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_line=True)
+        elif self.type == 2:
+            show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_box=True)
+        elif self.type == 3:
+            # post-processing
+            t_start = time.time()
+            pred = box_line_(img_, predictions)  # match boxes with lines
+            t_end = time.time()
+            print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
+            show_predict(img_, pred, self.threshold, t_start)
+
+    def run(self):
+        """Run the prediction pipeline."""
+        self.predict()
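A short usage sketch of Predict1 follows; it assumes the weights have already been loaded into the model (for example via LineNet.load_best_model, added in this commit). The file names below are placeholders.

import torch

from models.line_detect.line_net import linenet_resnet50_fpn
from models.line_detect.predict import Predict1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = linenet_resnet50_fpn().to(device)
model = model.load_best_model(model, "weights/best.pth")  # restores 'model_state_dict'

# type: 0 = all plots, 1 = lines only, 2 = boxes only, 3 = box-line matches
Predict1(model, "sample.png", type=3, threshold=0.5, save_path=None).run()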

+ 10 - 10
models/line_detect/predict2.py

@@ -437,8 +437,8 @@ def predict(pt_path, model, img):
 
     im = img_tensor.permute(1, 2, 0)  # [H, W, 3]
     if im.shape != (512, 512, 3):
-        im_resized = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
-    img_ = torch.tensor(im_resized).permute(2, 0, 1)  # [3, 512, 512]
+        im = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
+    img_ = torch.tensor(im).permute(2, 0, 1)  # [3, 512, 512]
 
     t_end = time.time()
     print(f'switch img used:{t_end - t_start}')
@@ -456,12 +456,12 @@ def predict(pt_path, model, img):
     # show_box_or_line(img_, predictions, show_line=True, show_box=True)   # the flags decide what to draw
     # show_box_and_line(img_, predictions, show_line=True, show_box=True)  # draw both in a single 1x2 figure (2 plots)
 
-    t_start = time.time()
-    pred = box_line1(img_, predictions)
-    t_end = time.time()
-    print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
-
-    show_predict(img_, pred, t_start)
+    # t_start = time.time()
+    # pred = box_line1(img_, predictions)
+    # t_end = time.time()
+    # print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
+    #
+    # show_predict(img_, pred, t_start)
 
 
 if __name__ == '__main__':
@@ -470,9 +470,9 @@ if __name__ == '__main__':
     model = linenet_resnet50_fpn().to(device)
     # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练24轮结果.pth"
     # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
-    pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
+    pt_path = r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth"
     # pt_path = r"C:\Users\m2337\Downloads\best_e20.pth"
-    img_path = r"C:\Users\m2337\Desktop\p\新建文件夹\2025-03-25-16-10-00_SaveLeftImage.png"
+    img_path = r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png"
     predict(pt_path, model, img_path)
     t_end = time.time()
     print(f'predict used:{t_end - t_start}')