Sfoglia il codice sorgente

model.predict is OK

xue50 8 mesi fa
parent
commit
fb2d28ca65

+ 1 - 0
models/base/base_detection_net.py

@@ -11,6 +11,7 @@ from torch import nn, Tensor
 
 from libs.vision_libs.utils import _log_api_usage_once
 from models.base.base_model import BaseModel
+import matplotlib.pyplot as plt
 
 
 class BaseDetectionNet(BaseModel):

+ 11 - 3
models/line_detect/111.py

@@ -230,10 +230,18 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if __name__ == '__main__':
 
     # model = LineNet('line_net.yaml')
-    model=linenet_resnet50_fpn()
+    model=linenet_resnet50_fpn().to(device)
     #model=linenet_resnet18_fpn()
     # trainer = Trainer()
     # trainer.train_cfg(model,cfg='./train.yaml')
     # model.train_by_cfg(cfg='train.yaml')
-    trainer = Trainer()
-    trainer.train_cfg(model=model, cfg='train.yaml')
+    # trainer = Trainer()
+    # trainer.train_cfg(model=model, cfg='train.yaml')
+    pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
+    img_path = r"C:\Users\m2337\Desktop\p\新建文件夹\2025-03-25-16-10-00_SaveLeftImage.png"
+    model.predict(pt_path, model, img_path, type=1, threshold=0, save_path=None, show=True)
+
+
+
+
+

+ 1 - 1
models/line_detect/dataset_LD.py

@@ -87,7 +87,7 @@ class WirePointDataset(BaseDataset):
         else:
             img = self.default_transform(img)
 
-        # print(f'img:{img}')
+        # print(f'img:{img.shape}')
         return img, target
 
     def __len__(self):

+ 7 - 2
models/line_detect/line_net.py

@@ -26,6 +26,8 @@ from ..base import backbone_factory
 from ..base.base_detection_net import BaseDetectionNet
 import torch.nn.functional as F
 
+from .predict import Predict
+
 from ..config.config_tool import read_yaml
 
 FEATURE_DIM = 8
@@ -211,6 +213,10 @@ class LineNet(BaseDetectionNet):
         self.trainer = Trainer()
         self.trainer.train_cfg(model=self, cfg=cfg)
 
+    def predict(self, pt_path, model, img_path, type=0, threshold=0.5, save_path=None, show=False):
+        self.predict = Predict(pt_path, model, img_path, type, threshold, save_path, show)
+        self.predict.run()
+
 
 class TwoMLPHead(nn.Module):
     """
@@ -396,7 +402,6 @@ class LineNet_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
     weights=("pretrained", LineNet_ResNet50_FPN_Weights.COCO_V1),
     weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
 )
-
 def linenet_resnet18_fpn(
         *,
         weights: Optional[LineNet_ResNet50_FPN_Weights] = None,
@@ -406,7 +411,6 @@ def linenet_resnet18_fpn(
         trainable_backbone_layers: Optional[int] = None,
         **kwargs: Any,
 ) -> LineNet:
-
     # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
     # weights_backbone = ResNet50_Weights.verify(weights_backbone)
 
@@ -433,6 +437,7 @@ def linenet_resnet18_fpn(
 
     return model
 
+
 def linenet_resnet50_fpn(
         *,
         weights: Optional[LineNet_ResNet50_FPN_Weights] = None,

+ 325 - 84
models/line_detect/predict.py

@@ -1,3 +1,14 @@
+# import time
+# import torch
+# from PIL import Image
+# from torchvision import transforms
+# from skimage.transform import resize
+
+import time
+
+import cv2
+import skimage
+
 import os
 
 import torch
@@ -5,126 +16,356 @@ from PIL import Image
 import matplotlib.pyplot as plt
 import matplotlib as mpl
 import numpy as np
-from models.line_detect.line_net import linenet_resnet50_fpn
+# from models.line_detect.line_net import linenet_resnet50_fpn
 from torchvision import transforms
 
 from models.wirenet.postprocess import postprocess
+from rtree import index
+
+from datetime import datetime
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
-def load_best_model(model, save_path, device):
-    if os.path.exists(save_path):
-        checkpoint = torch.load(save_path, map_location=device)
-        model.load_state_dict(checkpoint['model_state_dict'])
-        # if optimizer is not None:
-        #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-        epoch = checkpoint['epoch']
-        loss = checkpoint['loss']
-        print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+def box_line_(imgs, pred):  # 默认置信度
+    im = imgs.permute(1, 2, 0).cpu().numpy()
+    lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    scores = pred[-1]['wires']['score'].cpu().numpy()[0]
+
+    # print(f'111:{len(lines)}')
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    line, score = postprocess(lines, scores, diag * 0.01, 0, False)
+    # print(f'333:{len(lines)}')
+    for idx, box_ in enumerate(pred[0:-1]):
+        box = box_['boxes']  # 是一个tensor
+
+        line_ = []
+        score_ = []
+
+        for i in box:
+            score_max = 0.0
+            tmp = [[0.0, 0.0], [0.0, 0.0]]
+
+            for j in range(len(line)):
+                if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
+                        line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
+                        line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
+                        line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+
+                    if score[j] > score_max:
+                        tmp = line[j]
+                        score_max = score[j]
+            line_.append(tmp)
+            score_.append(score_max)
+        processed_list = torch.tensor(np.array(line_))
+        pred[idx]['line'] = processed_list
+
+        processed_s_list = torch.tensor(score_)
+        pred[idx]['line_score'] = processed_s_list
+    return pred
+
+
+def set_thresholds(threshold):
+    if isinstance(threshold, list):
+        if len(threshold) != 2:
+            raise ValueError("Threshold list must contain exactly two elements.")
+        a, b = threshold
+    elif isinstance(threshold, (int, float)):
+        a = b = threshold
     else:
-        print(f"No saved model found at {save_path}")
-    return model
+        raise TypeError("Threshold must be either a list of two numbers or a single number.")
+
+    return a, b
+
+
+def color():
+    return [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
 
 
-cmap = plt.get_cmap("jet")
-norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
-sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
-sm.set_array([])
+def show_all(imgs, pred, threshold, save_path):
+    col = color()
+    box_th, line_th = set_thresholds(threshold)
+    im = imgs.permute(1, 2, 0)
+
+    boxes = pred[0]['boxes'].cpu().numpy()
+    box_scores = pred[0]['scores'].cpu().numpy()
+    lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    scores = pred[-1]['wires']['score'].cpu().numpy()[0]
+
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    line, line_score = postprocess(lines, scores, diag * 0.01, 0, False)
 
+    fig, axs = plt.subplots(1, 3, figsize=(10, 10))
 
-def c(x):
-    return sm.to_rgba(x)
+    axs[0].imshow(np.array(im))
+    for idx, box in enumerate(boxes):
+        if box_scores[idx] < box_th:
+            continue
+        x0, y0, x1, y1 = box
+        axs[0].add_patch(
+            plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+    axs[0].set_title('Boxes')
 
+    axs[1].imshow(np.array(im))
+    for idx, (a, b) in enumerate(line):
+        if line_score[idx] < line_th:
+            continue
+        axs[1].scatter(a[1], a[0], c='#871F78', s=2)
+        axs[1].scatter(b[1], b[0], c='#871F78', s=2)
+        axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+    axs[1].set_title('Lines')
 
-def imshow(im):
-    plt.close()
+    axs[2].imshow(np.array(im))
+    lines = pred[0]['line'].cpu().numpy()
+    line_scores = pred[0]['line_score'].cpu().numpy()
+    idx = 0
+    tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
+    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
+        x0, y0, x1, y1 = box
+        # 框中无线的跳过
+        if np.array_equal(line, tmp):
+            continue
+        a, b = line
+        if box_score >= 0.7 or line_score >= 0.9:
+            axs[2].add_patch(
+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+            axs[2].scatter(a[1], a[0], c='#871F78', s=10)
+            axs[2].scatter(b[1], b[0], c='#871F78', s=10)
+            axs[2].plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
+            idx = idx + 1
+    axs[2].set_title('Boxes and Lines')
+
+    if save_path:
+        save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box_line.png')
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+        plt.savefig(save_path)
+        print(f"Saved result image to {save_path}")
+
+    # if show:
+    # 调整子图之间的距离,防止标题和标签重叠
     plt.tight_layout()
-    plt.imshow(im)
-    plt.colorbar(sm, fraction=0.046)
-    plt.xlim([0, im.shape[0]])
-    plt.ylim([im.shape[0], 0])
+    plt.show()
 
 
-def show_line(img, pred):
-    im = img.permute(1, 2, 0)
+def show_box_or_line(imgs, pred, threshold, save_path=None, show_line=False, show_box=False):
+    col = color()
+    box_th, line_th = set_thresholds(threshold)
+    im = imgs.permute(1, 2, 0)
 
-    # 创建图形和坐标轴
+    # 可视化预测结
     fig, ax = plt.subplots(figsize=(10, 10))
-    # 绘制原始图像
     ax.imshow(np.array(im))
-    # 绘制边界框
-    boxes = pred[0]['boxes'].cpu().numpy()
-    boxes_scores = pred[0]['scores'].cpu().numpy()
 
-    # for box in boxes:
-    #     x0, y0, x1, y1 = box
-    #     rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1)
-    #     ax.add_patch(rect)  # 将矩形添加到 Axes 对象上
+    if show_box:
+        boxes = pred[0]['boxes'].cpu().numpy()
+        box_scores = pred[0]['scores'].cpu().numpy()
+        for idx, box in enumerate(boxes):
+            if box_scores[idx] < box_th:
+                continue
+            x0, y0, x1, y1 = box
+            ax.add_patch(
+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+        if save_path:
+            save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box.png')
+            os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+            plt.savefig(save_path)
+            print(f"Saved result image to {save_path}")
+
+    if show_line:
+        lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+        scores = pred[-1]['wires']['score'].cpu().numpy()[0]
+
+        for i in range(1, len(lines)):
+            if (lines[i] == lines[0]).all():
+                lines = lines[:i]
+                scores = scores[:i]
+                break
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        line, line_score = postprocess(lines, scores, diag * 0.01, 0, False)
+
+        for idx, (a, b) in enumerate(line):
+            if line_score[idx] < line_th:
+                continue
+            ax.scatter(a[1], a[0], c='#871F78', s=2)
+            ax.scatter(b[1], b[0], c='#871F78', s=2)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+        if save_path:
+            save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'line.png')
+            os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+            plt.savefig(save_path)
+            print(f"Saved result image to {save_path}")
+
+    plt.show()
+
+
+def show_predict(imgs, pred, threshold, t_start):
+    col = color()
+    box_th, line_th = set_thresholds(threshold)
+    im = imgs.permute(1, 2, 0)  # 处理为 [512, 512, 3]
+    boxes = pred[0]['boxes'].cpu().numpy()
+    box_scores = pred[0]['scores'].cpu().numpy()
+    lines = pred[0]['line'].cpu().numpy()
+    scores = pred[0]['line_score'].cpu().numpy()
 
-    for b, s in zip(boxes, boxes_scores):
-        # print(f'box:{b}, box_score:{s}')
-        if s < 0.7:
-            continue
-        x0, y0, x1, y1 = b
-        rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1)
-        ax.add_patch(rect)  # 将矩形添加到 Axes 对象上
-
-    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
-    H = pred[-1]['wires']
-    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
-    scores = H["score"][0].cpu().numpy()
     for i in range(1, len(lines)):
         if (lines[i] == lines[0]).all():
             lines = lines[:i]
             scores = scores[:i]
             break
-
-    # 后处理线条以去除重叠的线条
     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
-    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+    line1, line_score1 = postprocess(lines, scores, diag * 0.01, 0, False)
+
+    # 可视化预测结
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(im))
+    idx = 0
+
+    tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
+
+    for box, line, box_score, line_score in zip(boxes, line1, box_scores, line_score1):
+        x0, y0, x1, y1 = box
+        # 框中无线的跳过
+        if np.array_equal(line, tmp):
+            continue
+        a, b = line
+        if box_score >= box_th or line_score >= line_th:
+            ax.add_patch(
+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+            ax.scatter(a[1], a[0], c='#871F78', s=10)
+            ax.scatter(b[1], b[0], c='#871F78', s=10)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
+            idx = idx + 1
+    t_end = time.time()
+    print(f'predict used:{t_end - t_start}')
 
-    # 根据分数绘制线条
-    for i, t in enumerate([0.9]):
-        for (a, b), s in zip(nlines, nscores):
-            if s < t:
-                continue
-            ax.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)  # 在 Axes 上绘制线条
-            ax.scatter(a[1], a[0], **PLTOPTS)  # 在 Axes 上绘制散点
-            ax.scatter(b[1], b[0], **PLTOPTS)  # 在 Axes 上绘制散点
-
-    # 隐藏坐标轴
-    ax.set_axis_off()
-    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
-    plt.margins(0, 0)
-    ax.xaxis.set_major_locator(plt.NullLocator())
-    ax.yaxis.set_major_locator(plt.NullLocator())
-
-    # 显示图像
     plt.show()
 
 
-def predict(pt_path, model, img):
-    model = load_best_model(model, pt_path, device)
+class Predict:
+    def __init__(self, pt_path, model, img, type=0, threshold=0.5, save_path=None, show_line=False, show_box=False):
+        """
+        初始化预测器。
+
+        参数:
+            pt_path: 模型权重文件路径。
+            model: 模型定义(未加载权重)。
+            img: 输入图像(路径或 PIL 图像对象)。
+            type: 预测类型(0: 全部显示,线图、框图、线框匹配图,1: 显示线图,2: 显示框图,3: 线框匹配图)。
+            threshold: 阈值,用于过滤预测结果。
+            save_path: 保存结果的路径(可选)。
+            show: 是否显示结果。
+            device: 运行设备(默认 'cuda')。
+        """
+        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
+        self.model = model
+        self.pt_path = pt_path
+        self.img = self.load_image(img)
+        self.type = type
+        self.threshold = threshold
+        self.save_path = save_path
+        self.show_line = show_line
+        self.show_box = show_box
+
+    def load_best_model(self, model, save_path, device):
+        if os.path.exists(save_path):
+            checkpoint = torch.load(save_path, map_location=device)
+            model.load_state_dict(checkpoint['model_state_dict'])
+            # if optimizer is not None:
+            #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            # epoch = checkpoint['epoch']
+            # loss = checkpoint['loss']
+            # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+        else:
+            print(f"No saved model found at {save_path}")
+        return model
+
+    def load_image(self, img):
+        """加载图像"""
+        if isinstance(img, str):
+            img = Image.open(img).convert("RGB")
+        return img
+
+    def preprocess_image(self, img):
+        """预处理图像"""
+        transform = transforms.ToTensor()
+        img_tensor = transform(img)  # [3, H, W]
+
+        # 调整大小为 512x512
+        t_start = time.time()
+        im = img_tensor.permute(1, 2, 0)  # [H, W, 3]
+        # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
+        if im.shape != (512, 512, 3):
+            im_resized = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
+        img_ = torch.tensor(im_resized).permute(2, 0, 1)  # [3, 512, 512]
+        t_end = time.time()
+        print(f"Image preprocessing used: {t_end - t_start:.4f} seconds")
+
+        return img_
 
-    model.eval()
+    def predict(self):
+        """执行预测"""
+        model = self.load_best_model(self.model, self.pt_path, device)
 
-    if isinstance(img, str):
-        img = Image.open(img).convert("RGB")
+        model.eval()
 
-    transform = transforms.ToTensor()
-    img_tensor = transform(img)
+        # 预处理图像
+        img_ = self.preprocess_image(self.img)
 
-    with torch.no_grad():
-        predictions = model([img_tensor])
-        print(predictions[0])
+        # 模型推理
+        with torch.no_grad():
+            predictions = model([img_.to(self.device)])
+            print("Model predictions completed.")
 
-    show_line(img_tensor, predictions)
+        # 后处理
+        t_start = time.time()
+        pred = box_line_(img_, predictions)  # 线框匹配
+        t_end = time.time()
+        print(f"Matched boxes and lines used: {t_end - t_start:.4f} seconds")
 
+        # 根据类型显示或保存结果
+        if self.type == 0:
+            show_all(img_, pred, self.threshold, save_path=self.save_path)
+        elif self.type == 1:
+            show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_line=True)
+        elif self.type == 2:
+            show_box_or_line(img_, predictions, self.threshold, save_path=self.save_path, show_box=True)
+        elif self.type == 3:
+            show_predict(img_, pred, self.threshold, t_start)
 
-if __name__ == '__main__':
-    model = linenet_resnet50_fpn()
-    pt_path = r'D:\python\PycharmProjects\lcnn-master\lcnn_\20250212\MultiVisionModels\models\line_detect\linenet_wts\resnet50_best_e8.pth'
-    # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png'  # 工件图
-    img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe图
-    predict(pt_path, model, img_path)
+    def run(self):
+        """运行预测流程"""
+        self.predict()

+ 41 - 29
models/line_detect/predict2.py

@@ -155,6 +155,7 @@
 
 import time
 
+import cv2
 import skimage
 
 import os
@@ -249,28 +250,28 @@ def box_line1(imgs, pred):  # 默认置信度
                     if score[j] > score_max:
                         tmp = line[j]
                         score_max = score[j]
-            # 如果 box 内无线段,则通过点坐标找最长线段
-            if score_max == 0.0:  # 说明 box 内无线段
-                box_points = [
-                    [x, y] for x, y in points
-                    if i[0] <= y <= i[2] and i[1] <= x <= i[3]
-                ]
-
-                if len(box_points) >= 2:  # 至少需要两个点才能组成线段
-                    max_distance = 0.0
-                    longest_segment = [[0.0, 0.0], [0.0, 0.0]]
-
-                    # 找出 box 内点组成的最长线段
-                    for p1 in box_points:
-                        for p2 in box_points:
-                            if p1 != p2:
-                                distance = np.linalg.norm(np.array(p1) - np.array(p2))
-                                if distance > max_distance:
-                                    max_distance = distance
-                                    longest_segment = [p1, p2]
-
-                    tmp = longest_segment
-                    score_max = 0.0  # 默认分数为 0.0
+            # # 如果 box 内无线段,则通过点坐标找最长线段
+            # if score_max == 0.0:  # 说明 box 内无线段
+            #     box_points = [
+            #         [x, y] for x, y in points
+            #         if i[0] <= y <= i[2] and i[1] <= x <= i[3]
+            #     ]
+            #
+            #     if len(box_points) >= 2:  # 至少需要两个点才能组成线段
+            #         max_distance = 0.0
+            #         longest_segment = [[0.0, 0.0], [0.0, 0.0]]
+            #
+            #         # 找出 box 内点组成的最长线段
+            #         for p1 in box_points:
+            #             for p2 in box_points:
+            #                 if p1 != p2:
+            #                     distance = np.linalg.norm(np.array(p1) - np.array(p2))
+            #                     if distance > max_distance:
+            #                         max_distance = distance
+            #                         longest_segment = [p1, p2]
+            #
+            #         tmp = longest_segment
+            #         score_max = 0.0  # 默认分数为 0.0
 
             line_.append(tmp)
             score_.append(score_max)
@@ -383,14 +384,19 @@ def show_predict(imgs, pred, t_start):
 
 def show_line(imgs, pred, t_start):
     im = imgs.permute(1, 2, 0)
-    line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    lines = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
     # print(pred[-1]['wires']['score'])
-    line_score = pred[-1]['wires']['score'].cpu().numpy()[0]
+    scores = pred[-1]['wires']['score'].cpu().numpy()[0]
 
     t1 = time.time()
     print(f't1:{t1 - t_start}')
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
-    line, line_score = postprocess(line, line_score, diag * 0.01, 0, False)
+    line, line_score = postprocess(lines, scores, diag * 0.01, 0, False)
     print(f'lines num:{len(line)}')
     t2 = time.time()
     print(f't1:{t2 - t1}')
@@ -425,9 +431,15 @@ def predict(pt_path, model, img):
 
     # 将图像调整为512x512大小
     t_start = time.time()
-    im = img_tensor.permute(1, 2, 0)  # [512, 512, 3]
-    im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
-    img_ = torch.tensor(im_resized).permute(2, 0, 1)
+    # im = img_tensor.permute(1, 2, 0)  # [512, 512, 3]
+    # im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
+    # img_ = torch.tensor(im_resized).permute(2, 0, 1)
+
+    im = img_tensor.permute(1, 2, 0)  # [H, W, 3]
+    if im.shape != (512, 512, 3):
+        im_resized = cv2.resize(im.cpu().numpy().astype(np.float32), (512, 512), interpolation=cv2.INTER_LINEAR)
+    img_ = torch.tensor(im_resized).permute(2, 0, 1)  # [3, 512, 512]
+
     t_end = time.time()
     print(f'switch img used:{t_end - t_start}')
 
@@ -458,7 +470,7 @@ if __name__ == '__main__':
     model = linenet_resnet50_fpn().to(device)
     # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练24轮结果.pth"
     # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
-    pt_path = r"C:\Users\m2337\Downloads\best_e150.pth"
+    pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
     # pt_path = r"C:\Users\m2337\Downloads\best_e20.pth"
     img_path = r"C:\Users\m2337\Desktop\p\新建文件夹\2025-03-25-16-10-00_SaveLeftImage.png"
     predict(pt_path, model, img_path)