Browse Source

修改model.predict

xue50 7 months ago
parent
commit
14a3221d39

+ 1 - 0
.gitignore

@@ -1,6 +1,7 @@
 .idea
 .idea
 *.pt
 *.pt
 *.pth
 *.pth
+*.png
 
 
 *.log
 *.log
 *.onnx
 *.onnx

+ 11 - 11
models/line_detect/111.py

@@ -48,6 +48,7 @@ class Trainer(BaseTrainer):
                  **kwargs):
                  **kwargs):
         super().__init__(model, dataset, device, **kwargs)
         super().__init__(model, dataset, device, **kwargs)
         self.freeze_config = freeze_config or {}  # 默认冻结配置为空
         self.freeze_config = freeze_config or {}  # 默认冻结配置为空
+
     def move_to_device(self, data, device):
     def move_to_device(self, data, device):
         if isinstance(data, (list, tuple)):
         if isinstance(data, (list, tuple)):
             return type(data)(self.move_to_device(item, device) for item in data)
             return type(data)(self.move_to_device(item, device) for item in data)
@@ -57,11 +58,12 @@ class Trainer(BaseTrainer):
             return data.to(device)
             return data.to(device)
         else:
         else:
             return data  # 对于非张量类型的数据不做任何改变
             return data  # 对于非张量类型的数据不做任何改变
+
     def freeze_params(self, model):
     def freeze_params(self, model):
         """根据配置冻结模型参数"""
         """根据配置冻结模型参数"""
         default_config = {
         default_config = {
             'backbone': True,  # 冻结 backbone
             'backbone': True,  # 冻结 backbone
-            'rpn': False,      # 不冻结 rpn
+            'rpn': False,  # 不冻结 rpn
             'roi_heads': {
             'roi_heads': {
                 'box_head': False,
                 'box_head': False,
                 'box_predictor': False,
                 'box_predictor': False,
@@ -226,11 +228,11 @@ class Trainer(BaseTrainer):
 import torch
 import torch
 
 
 from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn
 from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn
+
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if __name__ == '__main__':
 if __name__ == '__main__':
-
     # model = LineNet('line_net.yaml')
     # model = LineNet('line_net.yaml')
-    model=linenet_resnet50_fpn().to(device)
+    model = linenet_resnet50_fpn().to(device)
     # model=linenet_resnet18_fpn()
     # model=linenet_resnet18_fpn()
     # trainer = Trainer()
     # trainer = Trainer()
     # trainer.train_cfg(model,cfg='./train.yaml')
     # trainer.train_cfg(model,cfg='./train.yaml')
@@ -238,12 +240,10 @@ if __name__ == '__main__':
     # trainer = Trainer()
     # trainer = Trainer()
     # trainer.train_cfg(model=model, cfg='train.yaml')
     # trainer.train_cfg(model=model, cfg='train.yaml')
     #
     #
-    pt_path = r"E:\projects\tmp\MultiVisionModels\models\line_detect\train_results\20250424_162124\weights\best.pth"
-    img_path = r"I:\datasets\4_23jiagonggongjian\images\val\2025-04-23-08-52-00_SaveRightImage.png"
-
-    model.predict(pt_path, model, img_path, type=1, threshold=0, save_path=None, show=True)
-
-
-
-
+    # pt_path = r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth"
+    # img_path = r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png"
+    # model.predict(pt_path, model, img_path, type=1, threshold=0, save_path=None, show=True)
 
 
+    model = model.load_best_model(model, r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth")
+    img_path = r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png"
+    model.predict1(model, img_path, type=1, threshold=0, save_path=None, show=True)

+ 474 - 0
models/line_detect/aaa.py

@@ -0,0 +1,474 @@
+# from fastapi import FastAPI, File, UploadFile, HTTPException
+# from fastapi.responses import FileResponse
+# from fastapi.staticfiles import StaticFiles
+import os
+import torch
+import numpy as np
+from PIL import Image
+import skimage.io
+import skimage.color
+from torchvision import transforms
+import shutil
+import matplotlib.pyplot as plt
+from models.line_detect.line_net import linenet_resnet50_fpn
+from models.wirenet.postprocess import postprocess
+from rtree import index
+import time
+import multiprocessing as mp
+
+# from code.ubuntu2204.VisionWeldRobotMK.multivisionmodels.models.line_detect.boxline import show_box
+
+# ÉèÖÃ¶à½ø³ÌÆô¶¯·½Ê½Îª 'spawn'
+mp.set_start_method('spawn', force=True)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+def load_model(model_path):
+    """¼ÓÔØÄ£ÐͲ¢·µ»ØÄ£ÐÍʵÀý"""
+    model = linenet_resnet50_fpn().to(device)
+    if os.path.exists(model_path):
+        checkpoint = torch.load(model_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        print(f"Loaded model from {model_path}")
+    else:
+        raise FileNotFoundError(f"No saved model found at {model_path}")
+    model.eval()
+    return model
+
+def preprocess_image(image_path):
+    """Ô¤´¦ÀíÉÏ´«µÄͼƬ"""
+    img = Image.open(image_path).convert("RGB")
+    transform = transforms.ToTensor()
+    img_tensor = transform(img)
+    resized_img = skimage.transform.resize(
+        img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
+    )
+    return torch.tensor(resized_img).permute(2, 0, 1),img
+
+def save_plot(output_path: str):
+    """±£´æÍ¼Ïñ²¢¹Ø±Õ»æÍ¼"""
+    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
+    print(f"Saved plot to {output_path}")
+    plt.close()
+
+def get_colors():
+    """·µ»ØÒ»×éÔ¤¶¨ÒåµÄÑÕÉ«Áбí"""
+    return [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+
+def process_box(box, lines, scores):
+    """´¦Àíµ¥¸ö±ß½ç¿ò£¬ÕÒµ½×î¼ÑÆ¥ÅäµÄÏß¶Î"""
+    valid_lines = []  # ´æ´¢ÓÐЧµÄÏß¶Î
+    valid_scores = []  # ´æ´¢ÓÐЧµÄ·ÖÊý
+    # print(f'score:{len(scores)}')
+    for i in box:
+        best_line = None
+        max_length = 0.0
+        # ±éÀúËùÓÐÏß¶Î
+        for j in range(lines.shape[1]):
+            # line_j = lines[0, j].cpu().numpy() / 128 * 512
+            line_j = lines[0, j].cpu().numpy()
+            # ¼ì²éÏß¶ÎÊÇ·ñÍêÈ«ÔÚboxÄÚ
+            if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
+                    line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
+                    line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
+                    line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
+                # ¼ÆËãÏ߶㤶È
+                length = np.linalg.norm(line_j[0] - line_j[1])
+                # length = scores[j].cpu().numpy()
+                # print(length)
+                if length > max_length:
+                    best_line = line_j
+                    max_length = length
+        # Èç¹ûÕÒµ½ÓÐЧµÄÏ߶Σ¬ÔòÌí¼Óµ½½á¹ûÖÐ
+        if best_line is not None:
+            valid_lines.append(best_line)
+            valid_scores.append(max_length)  # ʹÓÃÏ߶㤶È×÷Ϊ·ÖÊý
+
+        else:
+            valid_lines.append([[0.0,0.0],[0.0,0.0]])
+            valid_scores.append(0.0)  # ʹÓÃÏß¶ÎÖÃÐŶÈ×÷Ϊ·ÖÊý
+        # print(f'valid_lines:{valid_lines}')
+        # print(f'valid_scores:{valid_scores}')
+    return valid_lines, valid_scores
+
+
+# def box_line_optimized_parallel(imgs, pred):  # ĬÈÏÖÃÐŶÈ
+#     im = imgs.permute(1, 2, 0).cpu().numpy()
+#     line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+#     line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
+#
+#     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+#     line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
+#     for idx, box_ in enumerate(pred[0:-1]):
+#         box = box_['boxes']
+#
+#         line_ = []
+#         score_ = []
+#
+#         for i in box:
+#             score_max = 0.0
+#             tmp = [[0.0, 0.0], [0.0, 0.0]]
+#
+#             for j in range(len(line)):
+#                 if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
+#                         line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
+#                         line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
+#                         line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+#
+#                     if score[j] > score_max:
+#                         tmp = line[j]
+#                         score_max = score[j]
+#             line_.append(tmp)
+#             score_.append(score_max)
+#         processed_list = torch.tensor(line_)
+#         pred[idx]['line'] = processed_list
+#
+#         processed_s_list = torch.tensor(score_)
+#         pred[idx]['line_score'] = processed_s_list
+#     del pred[-1]
+#     return pred
+
+def box_line_optimized_parallel(imgs, pred, length=False):    # 默认置信度
+    im = imgs.permute(1, 2, 0).cpu().numpy()
+    line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
+    # print(f'line_data:{line_data}')
+
+    points=pred[-1]['wires']['juncs'].cpu().numpy()[0]/ 128 * 512
+
+    is_all_zeros = np.all(line_data == 0.0)
+    if is_all_zeros:
+        for idx, box_ in enumerate(pred[0:-1]):
+            score_max = 0.0
+            tmp = [[0.0, 0.0], [0.0, 0.0]]
+            processed_list = torch.tensor(tmp)
+            pred[idx]['line'] = processed_list
+
+            processed_s_list = torch.tensor(score_max)
+            pred[idx]['line_score'] = processed_s_list
+    else:
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
+        for idx, box_ in enumerate(pred[0:-2]):
+            box = box_['boxes']  # 是一个tensor
+            line_ = []
+            score_ = []
+            for i in box:
+                score_max = 0.0
+                tmp = [[0.0, 0.0], [0.0, 0.0]]
+
+                for j in range(len(line)):
+                    if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
+                            line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
+                            line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
+                            line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+
+                        if score[j] > score_max:
+                            tmp = line[j]
+                            score_max = score[j]
+                # 如果 box 内无线段,则通过点坐标找最长线段
+                if score_max == 0.0:  # 说明 box 内无线段
+                    box_points = [
+                        [x, y] for x, y in points
+                        if i[0] <= y <= i[2] and i[1] <= x <= i[3]
+                    ]
+
+                    if len(box_points) >= 2:  # 至少需要两个点才能组成线段
+                        max_distance = 0.0
+                        longest_segment = [[0.0, 0.0], [0.0, 0.0]]
+
+                        # 找出 box 内点组成的最长线段
+                        for p1 in box_points:
+                            for p2 in box_points:
+                                if p1 != p2:
+                                    distance = np.linalg.norm(np.array(p1) - np.array(p2))
+                                    if distance > max_distance:
+                                        max_distance = distance
+                                        longest_segment = [p1, p2]
+
+                        tmp = longest_segment
+                        score_max = 0.0  # 默认分数为 0.0
+
+                line_.append(tmp)
+                score_.append(score_max)
+            processed_list = torch.tensor(line_)
+            pred[idx]['line'] = processed_list
+
+            processed_s_list = torch.tensor(score_)
+            pred[idx]['line_score'] = processed_s_list
+    return pred
+
+def show_predict1(imgs, pred, t_start):
+    col = [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+    im = imgs.permute(1, 2, 0)  # ´¦ÀíΪ [512, 512, 3]
+    boxes = pred[0]['boxes'].cpu().numpy()
+    box_scores = pred[0]['scores'].cpu().numpy()
+    lines = pred[0]['line'].cpu().numpy()
+    line_scores = pred[0]['line_score'].cpu().numpy()
+
+    # ¿ÉÊÓ»¯Ô¤²â½á
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(im))
+    idx = 0
+
+    tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
+
+    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
+        x0, y0, x1, y1 = box
+        # ¿òÖÐÎÞÏßµÄÌø¹ý
+        if np.array_equal(line, tmp):
+            continue
+        a, b = line
+        if box_score >= 0 or line_score >= 0:
+            ax.add_patch(
+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+            ax.scatter(a[1], a[0], c='#871F78', s=10)
+            ax.scatter(b[1], b[0], c='#871F78', s=10)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
+            idx = idx + 1
+    t_end = time.time()
+    print(f'predict used:{t_end - t_start}')
+    plt.savefig("temp_result.png")
+    plt.show()
+    # output_path = "temp_result.png"
+    # save_plot(output_path)
+    # return output_path
+
+
+def predict(image_path):
+
+    start_time = time.time()
+
+    model_path = r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth"
+    model = load_model(model_path)
+
+    img_tensor,_ = preprocess_image(image_path)
+    print(f'img shape:{img_tensor.shape}')
+
+    # Ä£ÐÍÍÆÀí
+    with torch.no_grad():
+      predictions = model([img_tensor.to(device)])
+    # print(f'predictions[0]:{predictions[0]}')
+    # print(f'predictions[1]:{predictions[1]["wires"]["lines"]}')
+    # lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 512 * np.array([2112, 1328])
+
+    start_time1 = time.time()
+    show_line(img_tensor.permute(1, 2, 0).cpu().numpy(), predictions, start_time1)
+    show_box(img_tensor.permute(1, 2, 0).cpu().numpy(), predictions)
+
+    H = predictions[-1]['wires']
+    lines = H["lines"][0].cpu().numpy() / 128 * np.array([2112, 1328])
+    scores = H["score"][0].cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # postprocess lines to remove overlapped lines
+    diag = (512 ** 2 + 512 ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+
+    # t_start = time.time()
+    # filtered_pred = box_line_optimized_parallel(img_tensor,predictions)
+    # # print(f'匹配后:{filtered_pred}')
+    # print(f'匹配后 len:{filtered_pred[0]}')
+    # t_end = time.time()
+    # print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')
+    # show_predict(img_tensor.permute(1, 2, 0).cpu().numpy(), filtered_pred, start_time1)
+
+
+    # # ºÏ²¢Í¼Ïñ
+    # combined_image_path = "combined_result.png"
+    # combine_images(
+    #     [output_path_boxandline, output_path_box, output_path_line],
+    #     titles=["Box and Line", "Box", "Line"],
+    #     output_path=combined_image_path
+    # )
+
+    # end_time = time.time()
+    # print(f'Total time: {end_time - start_time:.2f} seconds')
+
+    # lines = filtered_pred[0]['line'].cpu().numpy() / 512 * np.array([2112, 1328])
+    print(f'线段 len:{len(nlines)}')
+    # print(f"Initial lines shape: {lines.shape}")
+    # print(f"Initial lines data type: {lines.dtype}")
+
+    formatted_lines = []
+    for line in nlines:
+        if (line == [[0.0, 0.0], [0.0, 0.0]]).all():
+            continue
+            #line=[[500.0, 500.0], [650.0, 650.0]]
+        start_point = np.array([line[0][0], line[0][1]])
+        end_point = np.array([line[1][0], line[1][1]])
+        formatted_lines.append([start_point, end_point])
+
+    formatted_lines = np.array(formatted_lines)
+    print(f"Final formatted_lines shape: {formatted_lines.shape}")
+    print(f"Sample formatted line: {formatted_lines[0] if len(formatted_lines) > 0 else 'No lines'}")
+
+     # È·±£·µ»ØµÄÊÇÈýάÊý×飺[lines_array]
+    result = [formatted_lines]
+    print(f"Final result type: {type(result)}")
+    print(f"Final result[0] shape: {result[0].shape}")
+
+    return result
+
+
+
+def show_line(im, predictions, start_time1):
+    """»æÖÆÏ߶β¢±£´æ½á¹û"""
+    lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
+
+    is_all_zeros = np.all(lines == 0.0)
+    if is_all_zeros:
+        fig, ax = plt.subplots(figsize=(10, 10))
+        t_end = time.time()
+        plt.savefig("temp_line.png")
+    else:
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+        fig, ax = plt.subplots(figsize=(10, 10))
+        ax.imshow(im)
+        for (a, b), s in zip(nlines, nscores):
+            if s < 0:
+                continue
+            ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+        t_end = time.time()
+        plt.savefig("temp_line.png")
+    print(f'show line time:{t_end-start_time1}')
+
+
+    # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    # nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+    # fig, ax = plt.subplots(figsize=(10, 10))
+    # ax.imshow(im)
+    # for (a, b), s in zip(nlines, nscores):
+    #     if s < 0:
+    #         continue
+    #     ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
+    #     ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+    # t_end = time.time()
+    # plt.savefig("temp_line.png")
+
+def show_box(im, predictions):
+    """绘制边界框并保存结果"""
+    boxes = predictions[0]['boxes'].cpu().numpy()
+    box_scores = predictions[0]['scores'].cpu().numpy()
+    colors = get_colors()
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for idx, (box, score) in enumerate(zip(boxes, box_scores)):
+        if score < 0:
+            continue
+        x0, y0, x1, y1 = box
+        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=colors[idx % len(colors)], linewidth=1))
+    t_end = time.time()
+    plt.savefig("temp_box.png")
+
+def show_predict(im, filtered_pred, t_start):
+    """»æÖÆÆ¥ÅäºóµÄ±ß½ç¿òºÍÏ߶β¢±£´æ½á¹û"""
+    colors = get_colors()
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    pred = filtered_pred[0]
+
+    boxes = pred['boxes'].cpu().numpy()
+    box_scores = pred['scores'].cpu().numpy()
+    lines = pred['line'].cpu().numpy()
+    line_scores = pred['line_score'].cpu().numpy()
+    print("Boxes:", pred['boxes'])
+    print("Lines:", pred['line'])
+    print("Line scores:", pred['line_score'])
+
+    is_all_zeros = np.all(lines == 0.0)
+    if not is_all_zeros:
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+        for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
+            if box_score < 0.0 or line_score < 0.0:
+                continue
+
+            # Èç¹ûÏß¶ÎΪ¿Õ£¨¼´Ã»ÓÐÕÒµ½ÓÐЧÏ߶Σ©£¬Ìø¹ý»æÖÆ
+            if line is None or len(line) == 0:
+                continue
+
+            x0, y0, x1, y1 = box
+            a, b = line
+            color = colors[box_idx % len(colors)]  # ÿ¸ö±ß½ç¿ò·ÖÅäÒ»¸öΨһÑÕÉ«
+            ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
+            ax.scatter(a[1], a[0], c=color, s=10)
+            ax.scatter(b[1], b[0], c=color, s=10)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
+
+        # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        # nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+        # for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
+        #     if box_score < 0.0 or line_score < 0.0:
+        #         continue
+        #
+        #     # Èç¹ûÏß¶ÎΪ¿Õ£¨¼´Ã»ÓÐÕÒµ½ÓÐЧÏ߶Σ©£¬Ìø¹ý»æÖÆ
+        #     if line is None or len(line) == 0:
+        #         continue
+        #
+        #     x0, y0, x1, y1 = box
+        #     a, b = line
+        #     color = colors[(idx + box_idx) % len(colors)]  # ÿ¸ö±ß½ç¿ò·ÖÅäÒ»¸öΨһÑÕÉ«
+        #     ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
+        #     ax.scatter(a[1], a[0], c=color, s=10)
+        #     ax.scatter(b[1], b[0], c=color, s=10)
+        #     ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
+    t_end = time.time()
+    # plt.show()
+    print(f'show_predict used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result.png"
+    save_plot(output_path)
+    return output_path
+
+if __name__ == "__main__":
+    lines = predict(r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png")
+    print(f'lines:{lines}')

+ 22 - 1
models/line_detect/line_net.py

@@ -1,3 +1,4 @@
+import os
 from typing import Any, Callable, List, Optional, Tuple, Union
 from typing import Any, Callable, List, Optional, Tuple, Union
 import torch
 import torch
 from torch import nn
 from torch import nn
@@ -26,7 +27,7 @@ from ..base import backbone_factory
 from ..base.base_detection_net import BaseDetectionNet
 from ..base.base_detection_net import BaseDetectionNet
 import torch.nn.functional as F
 import torch.nn.functional as F
 
 
-from .predict import Predict
+from .predict import Predict1, Predict
 
 
 from ..config.config_tool import read_yaml
 from ..config.config_tool import read_yaml
 
 
@@ -213,10 +214,30 @@ class LineNet(BaseDetectionNet):
         self.trainer = Trainer()
         self.trainer = Trainer()
         self.trainer.train_cfg(model=self, cfg=cfg)
         self.trainer.train_cfg(model=self, cfg=cfg)
 
 
+    def load_best_model(self,model,  save_path, device='cuda'):
+        if os.path.exists(save_path):
+            checkpoint = torch.load(save_path, map_location=device)
+            model.load_state_dict(checkpoint['model_state_dict'])
+            # if optimizer is not None:
+            #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            # epoch = checkpoint['epoch']
+            # loss = checkpoint['loss']
+            # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+            print(f"Loaded model from {save_path}")
+        else:
+            print(f"No saved model found at {save_path}")
+        return model
+
+    # 加载权重和推理一起
     def predict(self, pt_path, model, img_path, type=0, threshold=0.5, save_path=None, show=False):
     def predict(self, pt_path, model, img_path, type=0, threshold=0.5, save_path=None, show=False):
         self.predict = Predict(pt_path, model, img_path, type, threshold, save_path, show)
         self.predict = Predict(pt_path, model, img_path, type, threshold, save_path, show)
         self.predict.run()
         self.predict.run()
 
 
+    # 不加载权重
+    def predict1(self, model, img_path, type=0, threshold=0.5, save_path=None, show=False):
+        self.predict = Predict1(model, img_path, type, threshold, save_path, show)
+        self.predict.run()
+
 
 
 class TwoMLPHead(nn.Module):
 class TwoMLPHead(nn.Module):
     """
     """

+ 102 - 2
models/line_detect/predict.py

@@ -41,7 +41,7 @@ def box_line_(imgs, pred):  # 默认置信度
     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
     line, score = postprocess(lines, scores, diag * 0.01, 0, False)
     line, score = postprocess(lines, scores, diag * 0.01, 0, False)
     # print(f'333:{len(lines)}')
     # print(f'333:{len(lines)}')
-    for idx, box_ in enumerate(pred[0:-1]):
+    for idx, box_ in enumerate(pred[0:-2]):
         box = box_['boxes']  # 是一个tensor
         box = box_['boxes']  # 是一个tensor
 
 
         line_ = []
         line_ = []
@@ -330,7 +330,7 @@ class Predict:
         # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
         # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
         if im.shape != (512, 512, 3):
         if im.shape != (512, 512, 3):
             im_resized = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
             im_resized = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
-        img_ = torch.tensor(im_resized).permute(2, 0, 1)  # [3, 512, 512]
+        img_ = torch.tensor(im).permute(2, 0, 1)  # [3, 512, 512]
         t_end = time.time()
         t_end = time.time()
         print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")
         print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")
 
 
@@ -369,3 +369,103 @@ class Predict:
     def run(self):
     def run(self):
         """运行预测流程"""
         """运行预测流程"""
         self.predict()
         self.predict()
+
+class Predict1:
+    def __init__(self, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
+        """
+        初始化预测器。
+
+        参数:
+            pt_path: 模型权重文件路径。
+            model: 模型定义(未加载权重)。
+            img: 输入图像(路径或 PIL 图像对象)。
+            type: 预测类型(0: 全部显示,线图、框图、线框匹配图,1: 显示线图,2: 显示框图,3: 线框匹配图)。
+            threshold: 阈值,用于过滤预测结果。
+            save_path: 保存结果的路径(可选)。
+            show: 是否显示结果。
+            device: 运行设备(默认 'cuda')。
+        """
+        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
+        self.model = model
+        self.img = self.load_image(img)
+        self.type = type
+        self.threshold = threshold
+        self.save_path = save_path
+        self.show_line = show_line
+        self.show_box = show_box
+
+    def load_best_model(self, model, save_path, device):
+        if os.path.exists(save_path):
+            checkpoint = torch.load(save_path, map_location=device)
+            model.load_state_dict(checkpoint['model_state_dict'])
+            # if optimizer is not None:
+            #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            # epoch = checkpoint['epoch']
+            # loss = checkpoint['loss']
+            # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+        else:
+            print(f"No saved model found at {save_path}")
+        return model
+
+    def load_image(self, img):
+        """加载图像"""
+        if isinstance(img, str):
+            img = Image.open(img).convert("RGB")
+        return img
+
+    def preprocess_image(self, img):
+        """预处理图像"""
+        transform = transforms.ToTensor()
+        img_tensor = transform(img)  # [3, H, W]
+
+        # 调整大小为 512x512
+        t_start = time.time()
+        im = img_tensor.permute(1, 2, 0)  # [H, W, 3]
+        # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
+        if im.shape != (512, 512, 3):
+            im_resized = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
+        img_ = torch.tensor(im).permute(2, 0, 1)  # [3, 512, 512]
+        t_end = time.time()
+        print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")
+
+        return img_
+
+    def predict(self):
+        """执行预测"""
+        # model = self.load_best_model(self.model, self.pt_path, device)
+        model = self.model
+
+        model.eval()
+
+        # 预处理图像
+        img_ = self.preprocess_image(self.img)
+
+        # 模型推理
+        with torch.no_grad():
+            predictions = model([img_.to(self.device)])
+            print("Model predictions completed.")
+
+
+        # 根据类型显示或保存结果
+        if self.type == 0:
+            # 后处理
+            t_start = time.time()
+            pred = box_line_(img_, predictions)  # 线框匹配
+            t_end = time.time()
+            print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
+            show_all(img_, pred, self.threshold, save_path=self.save_path)
+        elif self.type == 1:
+            show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_line=True)
+        elif self.type == 2:
+            show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_box=True)
+        elif self.type == 3:
+            # 后处理
+            t_start = time.time()
+            pred = box_line_(img_, predictions)  # 线框匹配
+            t_end = time.time()
+            print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
+            show_predict(img_, pred, self.threshold, t_start)
+
+    def run(self):
+        """运行预测流程"""
+        self.predict()

+ 10 - 10
models/line_detect/predict2.py

@@ -437,8 +437,8 @@ def predict(pt_path, model, img):
 
 
     im = img_tensor.permute(1, 2, 0)  # [H, W, 3]
     im = img_tensor.permute(1, 2, 0)  # [H, W, 3]
     if im.shape != (512, 512, 3):
     if im.shape != (512, 512, 3):
-        im_resized = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
-    img_ = torch.tensor(im_resized).permute(2, 0, 1)  # [3, 512, 512]
+        im = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
+    img_ = torch.tensor(im).permute(2, 0, 1)  # [3, 512, 512]
 
 
     t_end = time.time()
     t_end = time.time()
     print(f'switch img used:{t_end - t_start}')
     print(f'switch img used:{t_end - t_start}')
@@ -456,12 +456,12 @@ def predict(pt_path, model, img):
     # show_box_or_line(img_, predictions, show_line=True, show_box=True)   # 参数确定画什么
     # show_box_or_line(img_, predictions, show_line=True, show_box=True)   # 参数确定画什么
     # show_box_and_line(img_, predictions, show_line=True, show_box=True)  # 一起画 1x2 2张图
     # show_box_and_line(img_, predictions, show_line=True, show_box=True)  # 一起画 1x2 2张图
 
 
-    t_start = time.time()
-    pred = box_line1(img_, predictions)
-    t_end = time.time()
-    print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
-
-    show_predict(img_, pred, t_start)
+    # t_start = time.time()
+    # pred = box_line1(img_, predictions)
+    # t_end = time.time()
+    # print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
+    #
+    # show_predict(img_, pred, t_start)
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
@@ -470,9 +470,9 @@ if __name__ == '__main__':
     model = linenet_resnet50_fpn().to(device)
     model = linenet_resnet50_fpn().to(device)
     # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练24轮结果.pth"
     # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练24轮结果.pth"
     # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
     # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
-    pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
+    pt_path = r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth"
     # pt_path = r"C:\Users\m2337\Downloads\best_e20.pth"
     # pt_path = r"C:\Users\m2337\Downloads\best_e20.pth"
-    img_path = r"C:\Users\m2337\Desktop\p\新建文件夹\2025-03-25-16-10-00_SaveLeftImage.png"
+    img_path = r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png"
     predict(pt_path, model, img_path)
     predict(pt_path, model, img_path)
     t_end = time.time()
     t_end = time.time()
     print(f'predict used:{t_end - t_start}')
     print(f'predict used:{t_end - t_start}')