
Fix the box problem; revert to the previous version

xue50, 7 months ago
Parent commit: 78acc468a5
3 changed files with 724 additions and 303 deletions
  1. models/dataset_tool.py (+190 -283)
  2. models/line_detect/dataset_LD.py (+11 -20)
  3. models/line_detect/infer.py (+523 -0)

+ 190 - 283
models/dataset_tool.py

@@ -1,234 +1,40 @@
-import cv2
-import numpy as np
-import torch
-import torchvision
-from matplotlib import pyplot as plt
-import tools.transforms as reference_transforms
-from collections import defaultdict
+# Dataset producing targets in the format expected by roi_head
+from torch.utils.data.dataset import T_co
 
-from tools import presets
+from models.base.base_dataset import BaseDataset
 
+import glob
 import json
+import math
+import os
+import random
+import cv2
+import PIL
 
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+from torchvision.utils import draw_bounding_boxes
 
-def get_modules(use_v2):
-    # We need a protected import to avoid the V2 warning in case just V1 is used
-    if use_v2:
-        import torchvision.transforms.v2
-        import torchvision.tv_tensors
-
-        return torchvision.transforms.v2, torchvision.tv_tensors
-    else:
-        return reference_transforms, None
-
-
-class Augmentation:
-    # Note: this transform assumes that the input to forward() are always PIL
-    # images, regardless of the backend parameter.
-    def __init__(
-            self,
-            *,
-            data_augmentation,
-            hflip_prob=0.5,
-            mean=(123.0, 117.0, 104.0),
-            backend="pil",
-            use_v2=False,
-    ):
-
-        T, tv_tensors = get_modules(use_v2)
-
-        transforms = []
-        backend = backend.lower()
-        if backend == "tv_tensor":
-            transforms.append(T.ToImage())
-        elif backend == "tensor":
-            transforms.append(T.PILToTensor())
-        elif backend != "pil":
-            raise ValueError(f"backend can be 'tv_tensor', 'tensor' or 'pil', but got {backend}")
-
-        if data_augmentation == "hflip":
-            transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
-        elif data_augmentation == "lsj":
-            transforms += [
-                T.ScaleJitter(target_size=(1024, 1024), antialias=True),
-                # TODO: FixedSizeCrop below doesn't work on tensors!
-                reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
-                T.RandomHorizontalFlip(p=hflip_prob),
-            ]
-        elif data_augmentation == "multiscale":
-            transforms += [
-                T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
-                T.RandomHorizontalFlip(p=hflip_prob),
-            ]
-        elif data_augmentation == "ssd":
-            fill = defaultdict(lambda: mean, {tv_tensors.Mask: 0}) if use_v2 else list(mean)
-            transforms += [
-                T.RandomPhotometricDistort(),
-                T.RandomZoomOut(fill=fill),
-                T.RandomIoUCrop(),
-                T.RandomHorizontalFlip(p=hflip_prob),
-            ]
-        elif data_augmentation == "ssdlite":
-            transforms += [
-                T.RandomIoUCrop(),
-                T.RandomHorizontalFlip(p=hflip_prob),
-            ]
-        else:
-            raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
-
-        if backend == "pil":
-            # Note: we could just convert to pure tensors even in v2.
-            transforms += [T.ToImage() if use_v2 else T.PILToTensor()]
-
-        transforms += [T.ToDtype(torch.float, scale=True)]
-
-        if use_v2:
-            transforms += [
-                T.ConvertBoundingBoxFormat(tv_tensors.BoundingBoxFormat.XYXY),
-                T.SanitizeBoundingBoxes(),
-                T.ToPureTensor(),
-            ]
-
-        self.transforms = T.Compose(transforms)
-
-    def __call__(self, img, target):
-        return self.transforms(img, target)
-
-
-def read_polygon_points(lbl_path, shape):
-    """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
-    polygon_points = []
-    w, h = shape[:2]
-    with open(lbl_path, 'r') as f:
-        lines = f.readlines()
-
-    for line in lines:
-        parts = line.strip().split()
-        class_id = int(parts[0])
-        points = np.array(parts[1:], dtype=np.float32).reshape(-1, 2)  # read the point coordinates
-        points[:, 0] *= h
-        points[:, 1] *= w
-
-        polygon_points.append((class_id, points))
-
-    return polygon_points
-
-
-def read_masks_from_pixels(lbl_path, shape):
-    """读取纯像素点格式的文件,不是轮廓像素点"""
-    h, w = shape
-    masks = []
-    labels = []
-
-    with open(lbl_path, 'r') as reader:
-        lines = reader.readlines()
-        mask_points = []
-        for line in lines:
-            mask = torch.zeros((h, w), dtype=torch.uint8)
-            parts = line.strip().split()
-            # print(f'parts:{parts}')
-            cls = torch.tensor(int(parts[0]), dtype=torch.int64)
-            labels.append(cls)
-            x_array = parts[1::2]
-            y_array = parts[2::2]
-
-            for x, y in zip(x_array, y_array):
-                x = float(x)
-                y = float(y)
-                mask_points.append((int(y * h), int(x * w)))
-
-            for p in mask_points:
-                mask[p] = 1
-            masks.append(mask)
-    reader.close()
-    return labels, masks
-
-
-def create_masks_from_polygons(polygons, image_shape):
-    """创建一个与图像尺寸相同的掩码,并填充多边形轮廓"""
-    colors = np.array([plt.cm.Spectral(each) for each in np.linspace(0, 1, len(polygons))])
-    masks = []
-
-    for polygon_data, col in zip(polygons, colors):
-        mask = np.zeros(image_shape[:2], dtype=np.uint8)
-        # convert the polygon vertices to a NumPy array
-        _, polygon = polygon_data
-        pts = np.array(polygon, np.int32).reshape((-1, 1, 2))
-
-        # fill the polygon with OpenCV's fillPoly
-        # print(f'color:{col[:3]}')
-        cv2.fillPoly(mask, [pts], np.array(col[:3]) * 255)
-        mask = torch.from_numpy(mask)
-        mask[mask != 0] = 1
-        masks.append(mask)
-
-    return masks
-
-
-def read_masks_from_txt(label_path, shape):
-    polygon_points = read_polygon_points(label_path, shape)
-    masks = create_masks_from_polygons(polygon_points, shape)
-    labels = [torch.tensor(item[0]) for item in polygon_points]
-
-    return labels, masks
-
-
-def masks_to_boxes(masks: torch.Tensor, ) -> torch.Tensor:
-    """
-    Compute the bounding boxes around the provided masks.
-
-    Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
-    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
-
-    Args:
-        masks (Tensor[N, H, W]): masks to transform where N is the number of masks
-            and (H, W) are the spatial dimensions.
-
-    Returns:
-        Tensor[N, 4]: bounding boxes
-    """
-    # if not torch.jit.is_scripting() and not torch.jit.is_tracing():
-    #     _log_api_usage_once(masks_to_boxes)
-    if masks.numel() == 0:
-        return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
-
-    n = masks.shape[0]
-
-    bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)
-
-    for index, mask in enumerate(masks):
-        y, x = torch.where(mask != 0)
-        bounding_boxes[index, 0] = torch.min(x)
-        bounding_boxes[index, 1] = torch.min(y)
-        bounding_boxes[index, 2] = torch.max(x)
-        bounding_boxes[index, 3] = torch.max(y)
-        # debug to pixel datasets
-
-        if bounding_boxes[index, 0] == bounding_boxes[index, 2]:
-            bounding_boxes[index, 2] = bounding_boxes[index, 2] + 1
-            bounding_boxes[index, 0] = bounding_boxes[index, 0] - 1
+import numpy as np
+import numpy.linalg as LA
+import torch
+from skimage import io
+from torch.utils.data import Dataset
+from torch.utils.data.dataloader import default_collate
 
-        if bounding_boxes[index, 1] == bounding_boxes[index, 3]:
-            bounding_boxes[index, 3] = bounding_boxes[index, 3] + 1
-            bounding_boxes[index, 1] = bounding_boxes[index, 1] - 1
+import matplotlib.pyplot as plt
+from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
 
-    return bounding_boxes
+from tools.presets import DetectionPresetTrain
 
-
-def line_boxes(target):
+def line_boxes1(target):
     boxs = []
-    lpre = target['wires']["lpre"].cpu().numpy() * 4
-    vecl_target = target['wires']["lpre_label"].cpu().numpy()
-    lpre = lpre[vecl_target == 1]
-
-    lines = lpre
-    sline = np.ones(lpre.shape[0])
+    lines = target.cpu().numpy() * 4
 
     if len(lines) > 0 and not (lines[0] == 0).all():
-        for i, ((a, b), s) in enumerate(zip(lines, sline)):
+        for i, ((a, b)) in enumerate(lines):
             if i > 0 and (lines[i] == lines[0]).all():
                 break
-            # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no guaranteed order
 
             if a[1] > b[1]:
                 ymax = a[1] + 10
@@ -244,71 +50,172 @@ def line_boxes(target):
                 xmax = b[0] + 10
             boxs.append([ymin, xmin, ymax, xmax])
 
+    # if boxs == []:
+    #     print(target)
+
     return torch.tensor(boxs)
 
 
-def read_polygon_points_wire(lbl_path, shape):
-    """读取 YOLOv8 格式的标注文件并解析多边形轮廓"""
-    polygon_points = []
-    w, h = shape[:2]
-    with open(lbl_path, 'r') as f:
-        lines = json.load(f)
-
-    for line in lines["segmentations"]:
-        parts = line["data"]
-        class_id = int(line["cls_id"])
-        points = np.array(parts, dtype=np.float32).reshape(-1, 2)  # read the point coordinates
-        points[:, 0] *= h
-        points[:, 1] *= w
-
-        polygon_points.append((class_id, points))
-
-    return polygon_points
-
-
-def read_masks_from_txt_wire(label_path, shape):
-    polygon_points = read_polygon_points_wire(label_path, shape)
-    masks = create_masks_from_polygons(polygon_points, shape)
-
-    labels = [torch.tensor(item[0]) for item in polygon_points]
-
-    return labels, masks
-
-
-def read_masks_from_pixels_wire(lbl_path, shape):
-    """读取纯像素点格式的文件,不是轮廓像素点"""
-    h, w = shape
-    masks = []
-    labels = []
-
-    with open(lbl_path, 'r') as reader:
-        lines = json.load(reader)
-        mask_points = []
-        for line in lines["segmentations"]:
-            # mask = torch.zeros((h, w), dtype=torch.uint8)
-            # parts = line["data"]
-            # print(f'parts:{parts}')
-            cls = torch.tensor(int(line["cls_id"]), dtype=torch.int64)
-            labels.append(cls)
-            # x_array = parts[0::2]
-            # y_array = parts[1::2]
-            # 
-            # for x, y in zip(x_array, y_array):
-            #     x = float(x)
-            #     y = float(y)
-            #     mask_points.append((int(y * h), int(x * w)))
-
-            # for p in mask_points:
-            #     mask[p] = 1
-            # masks.append(mask)
-    reader.close()
-    return labels
-
-
-def adjacency_matrix(n, link):  # adjacency matrix
-    mat = torch.zeros(n + 1, n + 1, dtype=torch.uint8)
-    link = torch.tensor(link)
-    if len(link) > 0:
-        mat[link[:, 0], link[:, 1]] = 1
-        mat[link[:, 1], link[:, 0]] = 1
-    return mat
+class WirePointDataset(BaseDataset):
+    def __init__(self, dataset_path, transforms=None, dataset_type=None, target_type='pixel'):
+        super().__init__(dataset_path)
+
+        self.data_path = dataset_path
+        print(f'data_path:{dataset_path}')
+        self.transforms = transforms
+        self.img_path = os.path.join(dataset_path, "images/" + dataset_type)
+        self.lbl_path = os.path.join(dataset_path, "labels/" + dataset_type)
+        self.imgs = os.listdir(self.img_path)
+        self.lbls = os.listdir(self.lbl_path)
+        self.target_type = target_type
+        # self.default_transform = DefaultTransform()
+
+    def __getitem__(self, index) -> T_co:
+        img_path = os.path.join(self.img_path, self.imgs[index])
+        lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-3] + 'json')
+
+        img = PIL.Image.open(img_path).convert('RGB')
+        w, h = img.size
+
+        # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        if self.transforms:
+            img, target = self.transforms(img, target)
+        else:
+            img = self.default_transform(img)
+
+        # print(f'img:{img.shape}')
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+    def read_target(self, item, lbl_path, shape, extra=None):
+        # print(f'lbl_path:{lbl_path}')
+        with open(lbl_path, 'r') as file:
+            lable_all = json.load(file)
+
+        n_stc_posl = 300
+        n_stc_negl = 40
+        use_cood = 0
+        use_slop = 0
+
+        wire = lable_all["wires"][0]  # first wire entry
+        line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl]  # sample at most n_stc_posl positive segments
+        line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
+        npos, nneg = len(line_pos_coords), len(line_neg_coords)
+        lpre = np.concatenate([line_pos_coords, line_neg_coords], 0)  # concatenate positive and negative segments
+        for i in range(len(lpre)):
+            if random.random() > 0.5:
+                lpre[i] = lpre[i, ::-1]
+        ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
+        ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
+        feat = [
+            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            ldir * use_slop,
+            lpre[:, :, 2],
+        ]
+        feat = np.concatenate(feat, 1)
+
+        wire_labels = {
+            "junc_coords": torch.tensor(wire["junc_coords"]["content"])[:, :2],
+            "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
+            "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
+            # adjacency matrix of positive line connections
+            "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
+            # adjacency matrix of negative line connections
+            "lpre": torch.tensor(lpre)[:, :, :2],
+            "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]),  # ?????? 1?0
+            "lpre_feat": torch.from_numpy(feat),
+            "junc_map": torch.tensor(wire['junc_map']["content"]),
+            "junc_offset": torch.tensor(wire['junc_offset']["content"]),
+            "line_map": torch.tensor(wire['line_map']["content"]),
+        }
+
+        labels = []
+        #
+        # if self.target_type == 'polygon':
+        #     labels, masks = read_masks_from_txt_wire(lbl_path, shape)
+        # elif self.target_type == 'pixel':
+        #     labels = read_masks_from_pixels_wire(lbl_path, shape)
+
+        # print(torch.stack(masks).shape)    # [num_masks, 512, 512]
+        target = {}
+        # target["labels"] = torch.stack(labels)
+
+
+        target["image_id"] = torch.tensor(item)
+        # return wire_labels, target
+        target["wires"] = wire_labels
+        # target["boxes"] = line_boxes(target)
+        target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
+        target["labels"]= torch.ones(len(target["boxes"]),dtype=torch.int64)
+        # print(f'target["labels"]:{ target["labels"]}')
+        # print(f'boxes:{target["boxes"].shape}')
+        if target["boxes"].numel() == 0:
+            print("Tensor is empty")
+            print(f'path:{lbl_path}')
+        return target
+
+    def show(self, idx):
+        image, target = self.__getitem__(idx)
+
+        cmap = plt.get_cmap("jet")
+        norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+        sm.set_array([])
+
+        def imshow(im):
+            plt.close()
+            plt.tight_layout()
+            plt.imshow(im)
+            plt.colorbar(sm, fraction=0.046)
+            plt.xlim([0, im.shape[0]])
+            plt.ylim([im.shape[0], 0])
+
+        def draw_vecl(lines, sline, juncs, junts, fn=None):
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            imshow(io.imread(img_path))
+            if len(lines) > 0 and not (lines[0] == 0).all():
+                for i, ((a, b), s) in enumerate(zip(lines, sline)):
+                    if i > 0 and (lines[i] == lines[0]).all():
+                        break
+                    plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no guaranteed order
+            if not (juncs[0] == 0).all():
+                for i, j in enumerate(juncs):
+                    if i > 0 and (i == juncs[0]).all():
+                        break
+                    plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
+
+
+            img_path = os.path.join(self.img_path, self.imgs[idx])
+            img = PIL.Image.open(img_path).convert('RGB')
+            boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
+                                              colors="yellow", width=1)
+            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
+            plt.show()
+
+            plt.show()
+            if fn != None:
+                plt.savefig(fn)
+
+        junc = target['wires']['junc_coords'].cpu().numpy() * 4
+        jtyp = target['wires']['jtyp'].cpu().numpy()
+        juncs = junc[jtyp == 0]
+        junts = junc[jtyp == 1]
+
+        lpre = target['wires']["lpre"].cpu().numpy() * 4
+        vecl_target = target['wires']["lpre_label"].cpu().numpy()
+        lpre = lpre[vecl_target == 1]
+
+        # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
+        draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
+
+
+    def show_img(self, img_path):
+        pass
+
+
+# dataset_train = WirePointDataset("/data/lm/dataset/0424_", dataset_type='val')
+# for i in dataset_train:
+#     a = 1
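
For reference, `line_boxes1` above pads each segment's endpoints by a fixed 10-pixel margin and appends boxes as [ymin, xmin, ymax, xmax], treating index 1 of each endpoint as y and index 0 as x. A minimal standalone sketch of the same arithmetic (the function name and sample endpoints are illustrative, not part of the commit):

    def segment_to_box(a, b, pad=10):
        # a, b: endpoints where, following line_boxes1, index 1 is y and index 0 is x
        ymin, ymax = min(a[1], b[1]) - pad, max(a[1], b[1]) + pad
        xmin, xmax = min(a[0], b[0]) - pad, max(a[0], b[0]) + pad
        return [ymin, xmin, ymax, xmax]

    segment_to_box((40, 100), (60, 150))  # -> [90, 30, 160, 70]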

+ 11 - 20
models/line_detect/dataset_LD.py

@@ -27,20 +27,17 @@ from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks
 
 from tools.presets import DetectionPresetTrain
 
-
 def line_boxes1(target):
     boxs = []
     lines = target.cpu().numpy() * 4
 
-    if len(lines) > 0 :
+    if len(lines) > 0 and not (lines[0] == 0).all():
         for i, ((a, b)) in enumerate(lines):
             if i > 0 and (lines[i] == lines[0]).all():
                 break
             # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1] have no guaranteed order
-            # if a[-1]==0. and b[-1]==0.:
-            #     continue
-            # if a[:2].tolist() == [0., 0.] and b[:2].tolist() == [0., 0.]:
-            #     continue
+            if a[-1]==0. and b[-1]==0.:
+                continue
 
             if a[1] > b[1]:
                 ymax = a[1] + 10
@@ -54,12 +51,10 @@ def line_boxes1(target):
             else:
                 xmin = a[0] - 10
                 xmax = b[0] + 10
-            boxs.append([min(ymin,0), min(xmin,0), max(ymax,512), max(xmax, 0)])
+            boxs.append([ymin, xmin, ymax, xmax])
 
-    # print(f'box:{boxs}')
     # if boxs == []:
-    #     print(f'box:{boxs}')
-    #     print(f'target:{target}')
+    #     print(target)
 
     return torch.tensor(boxs)
 
@@ -151,15 +146,14 @@ class WirePointDataset(BaseDataset):
         target = {}
         # target["labels"] = torch.stack(labels)
 
+
         target["image_id"] = torch.tensor(item)
         # return wire_labels, target
         target["wires"] = wire_labels
-        # target["boxes"] = line_boxes(target)
-        target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
-        target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
+        target["boxes"] = line_boxes(target)
+        # target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
+        target["labels"]= torch.ones(len(target["boxes"]),dtype=torch.int64)
         # print(f'target["labels"]:{ target["labels"]}')
-        # if target["boxes"].shape == [0]:
-        #     print(f'box is null:{lbl_path}')
         # print(f'boxes:{target["boxes"].shape}')
         return target
 
@@ -193,6 +187,7 @@ class WirePointDataset(BaseDataset):
                         break
                     plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
 
+
             img_path = os.path.join(self.img_path, self.imgs[idx])
             img = PIL.Image.open(img_path).convert('RGB')
             boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
@@ -216,10 +211,6 @@ class WirePointDataset(BaseDataset):
         # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
         draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
 
+
     def show_img(self, img_path):
         pass
-
-
-# dataset_train = WirePointDataset(r"\\192.168.50.222\share\lm\04\424-转分好的zjf", dataset_type='train')
-# for i in dataset_train:
-#     a = 1
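
With this revert, `target["boxes"]` is again computed by `line_boxes(target)` from the sampled positive segments stored in `target['wires']` (those with `lpre_label == 1`) instead of by `line_boxes1` from the raw `line_pos_coords`. A minimal sketch of that selection, mirroring the `line_boxes` body visible at the top of the models/dataset_tool.py diff above (the helper name is illustrative):

    def positive_segments(target):
        # keep only the sampled positive segments; * 4 upscales from the
        # 128-pixel feature grid to the 512-pixel input image
        lpre = target['wires']['lpre'].cpu().numpy() * 4
        keep = target['wires']['lpre_label'].cpu().numpy() == 1
        return lpre[keep]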

+ 523 - 0
models/line_detect/infer.py

@@ -0,0 +1,523 @@
+# from fastapi import FastAPI, File, UploadFile, HTTPException
+# from fastapi.responses import FileResponse
+# from fastapi.staticfiles import StaticFiles
+import os
+import torch
+import numpy as np
+from PIL import Image
+import skimage.io
+import skimage.color
+import skimage.transform  # needed by preprocess_image's skimage.transform.resize
+from torchvision import transforms
+import shutil
+import matplotlib.pyplot as plt
+from models.line_detect.line_net import linenet_resnet50_fpn
+from models.wirenet.postprocess import postprocess
+from rtree import index
+import time
+import multiprocessing as mp
+
+# from code.ubuntu2204.VisionWeldRobotMK.multivisionmodels.models.line_detect.boxline import show_box
+
+# set the multiprocessing start method to 'spawn'
+mp.set_start_method('spawn', force=True)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def load_model(model_path):
+    """¼ÓÔØÄ£ÐͲ¢·µ»ØÄ£ÐÍʵÀý"""
+    model = linenet_resnet50_fpn().to(device)
+    if os.path.exists(model_path):
+        checkpoint = torch.load(model_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        print(f"Loaded model from {model_path}")
+    else:
+        raise FileNotFoundError(f"No saved model found at {model_path}")
+    model.eval()
+    return model
+
+
+def preprocess_image(image_path):
+    """Ô¤´¦ÀíÉÏ´«µÄͼƬ"""
+    img = Image.open(image_path).convert("RGB")
+    transform = transforms.ToTensor()
+    img_tensor = transform(img)
+    resized_img = skimage.transform.resize(
+        img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
+    )
+    return torch.tensor(resized_img).permute(2, 0, 1), img
+
+
+def save_plot(output_path: str):
+    """±£´æÍ¼Ïñ²¢¹Ø±Õ»æÍ¼"""
+    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
+    print(f"Saved plot to {output_path}")
+    plt.close()
+
+
+def get_colors():
+    """·µ»ØÒ»×éÔ¤¶¨ÒåµÄÑÕÉ«Áбí"""
+    return [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+
+
+def process_box(box, lines, scores):
+    """´¦Àíµ¥¸ö±ß½ç¿ò£¬ÕÒµ½×î¼ÑÆ¥ÅäµÄÏß¶Î"""
+    valid_lines = []  # ´æ´¢ÓÐЧµÄÏß¶Î
+    valid_scores = []  # ´æ´¢ÓÐЧµÄ·ÖÊý
+    # print(f'score:{len(scores)}')
+    for i in box:
+        best_line = None
+        max_length = 0.0
+        # iterate over all line segments
+        for j in range(lines.shape[1]):
+            # line_j = lines[0, j].cpu().numpy() / 128 * 512
+            line_j = lines[0, j].cpu().numpy()
+            # check whether the segment lies entirely inside the box
+            if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
+                    line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
+                    line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
+                    line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
+                # compute the segment length
+                length = np.linalg.norm(line_j[0] - line_j[1])
+                # length = scores[j].cpu().numpy()
+                # print(length)
+                if length > max_length:
+                    best_line = line_j
+                    max_length = length
+        # if a valid segment was found, add it to the results
+        if best_line is not None:
+            valid_lines.append(best_line)
+            valid_scores.append(max_length)  # use the segment length as the score
+
+        else:
+            valid_lines.append([[0.0, 0.0], [0.0, 0.0]])
+            valid_scores.append(0.0)  # use the segment confidence as the score
+        # print(f'valid_lines:{valid_lines}')
+        # print(f'valid_scores:{valid_scores}')
+    return valid_lines, valid_scores
+
+
+# def box_line_optimized_parallel(imgs, pred):  # default confidence
+#     im = imgs.permute(1, 2, 0).cpu().numpy()
+#     line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+#     line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
+#
+#     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+#     line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
+#     for idx, box_ in enumerate(pred[0:-1]):
+#         box = box_['boxes']
+#
+#         line_ = []
+#         score_ = []
+#
+#         for i in box:
+#             score_max = 0.0
+#             tmp = [[0.0, 0.0], [0.0, 0.0]]
+#
+#             for j in range(len(line)):
+#                 if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
+#                         line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
+#                         line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
+#                         line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+#
+#                     if score[j] > score_max:
+#                         tmp = line[j]
+#                         score_max = score[j]
+#             line_.append(tmp)
+#             score_.append(score_max)
+#         processed_list = torch.tensor(line_)
+#         pred[idx]['line'] = processed_list
+#
+#         processed_s_list = torch.tensor(score_)
+#         pred[idx]['line_score'] = processed_s_list
+#     del pred[-1]
+#     return pred
+
+def box_line_optimized_parallel(imgs, pred, length=False):  # default confidence
+    im = imgs.permute(1, 2, 0).cpu().numpy()
+    line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
+    # print(f'line_data:{line_data}')
+
+    points = pred[-1]['wires']['juncs'].cpu().numpy()[0] / 128 * 512
+
+    is_all_zeros = np.all(line_data == 0.0)
+    if is_all_zeros:
+        for idx, box_ in enumerate(pred[0:-2]):
+            score_max = 0.0
+            tmp = [[0.0, 0.0], [0.0, 0.0]]
+            processed_list = torch.tensor(tmp)
+            pred[idx]['line'] = processed_list
+
+            processed_s_list = torch.tensor(score_max)
+            pred[idx]['line_score'] = processed_s_list
+    else:
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
+        for idx, box_ in enumerate(pred[0:-2]):
+            box = box_['boxes']  # a tensor
+            line_ = []
+            score_ = []
+            for i in box:
+                score_max = 0.0
+                tmp = [[0.0, 0.0], [0.0, 0.0]]
+
+                for j in range(len(line)):
+                    if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
+                            line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
+                            line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
+                            line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+
+                        if score[j] > score_max:
+                            tmp = line[j]
+                            score_max = score[j]
+                # if no segment lies inside the box, find the longest segment from the point coordinates
+                if score_max == 0.0:  # no segment inside the box
+                    box_points = [
+                        [x, y] for x, y in points
+                        if i[0] <= y <= i[2] and i[1] <= x <= i[3]
+                    ]
+
+                    if len(box_points) >= 2:  # at least two points are needed to form a segment
+                        max_distance = 0.0
+                        longest_segment = [[0.0, 0.0], [0.0, 0.0]]
+
+                        # find the longest segment formed by the points inside the box
+                        for p1 in box_points:
+                            for p2 in box_points:
+                                if p1 != p2:
+                                    distance = np.linalg.norm(np.array(p1) - np.array(p2))
+                                    if distance > max_distance:
+                                        max_distance = distance
+                                        longest_segment = [p1, p2]
+
+                        tmp = longest_segment
+                        score_max = 0.0  # default score is 0.0
+
+                line_.append(tmp)
+                score_.append(score_max)
+            processed_list = torch.tensor(line_)
+            pred[idx]['line'] = processed_list
+
+            processed_s_list = torch.tensor(score_)
+            pred[idx]['line_score'] = processed_s_list
+    return pred
+
+
+def show_predict1(imgs, pred, t_start):
+    col = [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+    im = imgs.permute(1, 2, 0)  # rearrange to [512, 512, 3]
+    boxes = pred[0]['boxes'].cpu().numpy()
+    box_scores = pred[0]['scores'].cpu().numpy()
+    lines = pred[0]['line'].cpu().numpy()
+    line_scores = pred[0]['line_score'].cpu().numpy()
+
+    # visualize the prediction results
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(im))
+    idx = 0
+
+    tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
+
+    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
+        x0, y0, x1, y1 = box
+        # skip boxes that contain no line
+        if np.array_equal(line, tmp):
+            continue
+        a, b = line
+        if box_score >= 0 or line_score >= 0:
+            ax.add_patch(
+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+            ax.scatter(a[1], a[0], c='#871F78', s=10)
+            ax.scatter(b[1], b[0], c='#871F78', s=10)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
+            idx = idx + 1
+    t_end = time.time()
+    print(f'predict used:{t_end - t_start}')
+    plt.savefig("temp_result.png")
+    plt.show()
+    # output_path = "temp_result.png"
+    # save_plot(output_path)
+    # return output_path
+
+
+# def predict(image_path):
+#
+#     start_time = time.time()
+#
+#     model_path = r'/home/zqy/docker_lstznkj_1.0/code/ubuntu2204/VisionWeldRobotMK/lcnn_yolo/best.pth'
+#     model = load_model(model_path)
+#
+#     img_tensor,_ = preprocess_image(image_path)
+#
+#     # model inference
+#     with torch.no_grad():
+#       predictions = model([img_tensor.to(device)])
+#     # print(f'predictions[0]:{predictions[0]}')
+#     # print(f'predictions[1]:{predictions[1]["wires"]["lines"]}')
+#
+#     start_time1 = time.time()
+#     show_line(img_tensor.permute(1, 2, 0).cpu().numpy(), predictions, start_time1)
+#     show_box(img_tensor.permute(1, 2, 0).cpu().numpy(), predictions)
+#
+#
+#     t_start = time.time()
+#     filtered_pred = box_line_optimized_parallel(img_tensor,predictions)
+#     # print(f'after matching:{filtered_pred}')
+#     print(f'after matching len:{filtered_pred[0]}')
+#     t_end = time.time()
+#     print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')
+#     show_predict(img_tensor.permute(1, 2, 0).cpu().numpy(), filtered_pred, start_time1)
+#
+#
+#     # # merge the images
+#     # combined_image_path = "combined_result.png"
+#     # combine_images(
+#     #     [output_path_boxandline, output_path_box, output_path_line],
+#     #     titles=["Box and Line", "Box", "Line"],
+#     #     output_path=combined_image_path
+#     # )
+#
+#     end_time = time.time()
+#     print(f'Total time: {end_time - start_time:.2f} seconds')
+#
+#     lines = filtered_pred[0]['line'].cpu().numpy() / 512 * np.array([2112, 1328])
+#     print(f'segments:{lines}')
+#     # print(f"Initial lines shape: {lines.shape}")
+#     # print(f"Initial lines data type: {lines.dtype}")
+#
+#     formatted_lines = []
+#     for line in lines:
+#         if (line == [[0.0, 0.0], [0.0, 0.0]]).all():
+#             continue
+#             #line=[[500.0, 500.0], [650.0, 650.0]]
+#         start_point = np.array([line[0][0], line[0][1]])
+#         end_point = np.array([line[1][0], line[1][1]])
+#         formatted_lines.append([start_point, end_point])
+#
+#     formatted_lines = np.array(formatted_lines)
+#     print(f"Final formatted_lines shape: {formatted_lines.shape}")
+#     print(f"Sample formatted line: {formatted_lines[0] if len(formatted_lines) > 0 else 'No lines'}")
+#
+#      # ensure the return value is a 3-D array: [lines_array]
+#     result = [formatted_lines]
+#     print(f"Final result type: {type(result)}")
+#     print(f"Final result[0] shape: {result[0].shape}")
+#
+#     return result
+def predict(image_path):
+    start_time = time.time()
+
+    model_path = r"\\192.168.50.222\share\rlq\weights\best0425.pth"
+    model = load_model(model_path)
+
+    img_tensor, _ = preprocess_image(image_path)
+    print(f'img shape:{img_tensor.shape}')
+
+    # model inference
+    with torch.no_grad():
+        predictions = model([img_tensor.to(device)])
+    # print(f'predictions[0]:{predictions[0]}')
+    # print(f'predictions[1]:{predictions[1]["wires"]["lines"]}')
+    # lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 512 * np.array([2112, 1328])
+
+    start_time1 = time.time()
+    show_line(img_tensor.permute(1, 2, 0).cpu().numpy(), predictions, start_time1)
+    show_box(img_tensor.permute(1, 2, 0).cpu().numpy(), predictions)
+
+    H = predictions[-1]['wires']
+    lines = H["lines"][0].cpu().numpy() / 128 * np.array([2112, 1328])
+    scores = H["score"][0].cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # postprocess lines to remove overlapped lines
+    diag = (512 ** 2 + 512 ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+    print(f'segments len:{len(nlines)}')
+    # print(f"Initial lines shape: {lines.shape}")
+    # print(f"Initial lines data type: {lines.dtype}")
+
+    formatted_lines = []
+    for line, score in zip(nlines, nscores):
+        if (line == [[0.0, 0.0], [0.0, 0.0]]).all() or score <= 0.7:
+            continue
+            # line=[[500.0, 500.0], [650.0, 650.0]]
+        start_point = np.array([line[0][0], line[0][1]])
+        end_point = np.array([line[1][0], line[1][1]])
+        formatted_lines.append([start_point, end_point])
+
+    formatted_lines = np.array(formatted_lines)
+    print(f"Final formatted_lines shape: {formatted_lines.shape}")
+    # print(f"Sample formatted line: {formatted_lines[0] if len(formatted_lines) > 0 else 'No lines'}")
+
+    # ensure the return value is a 3-D array: [lines_array]
+    result = [formatted_lines]
+    print(f"Final result type: {type(result)}")
+    print(f"Final result[0] shape: {result[0].shape}")
+
+    return result
+
+
+def show_line(im, predictions, start_time1):
+    """»æÖÆÏ߶β¢±£´æ½á¹û"""
+    lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
+
+    is_all_zeros = np.all(lines == 0.0)
+    if is_all_zeros:
+        fig, ax = plt.subplots(figsize=(10, 10))
+        t_end = time.time()
+        plt.savefig("/home/zqy/docker_lstznkj_1.0/code/ubuntu2204/VisionWeldRobotMK/temp_line.png")
+    else:
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+        fig, ax = plt.subplots(figsize=(10, 10))
+        ax.imshow(im)
+        for (a, b), s in zip(nlines, nscores):
+            if s < 0:
+                continue
+            ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+        t_end = time.time()
+        plt.savefig("/home/zqy/docker_lstznkj_1.0/code/ubuntu2204/VisionWeldRobotMK/temp_line.png")
+    print(f'show line time:{t_end - start_time1}')
+
+    # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    # nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+    # fig, ax = plt.subplots(figsize=(10, 10))
+    # ax.imshow(im)
+    # for (a, b), s in zip(nlines, nscores):
+    #     if s < 0:
+    #         continue
+    #     ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
+    #     ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+    # t_end = time.time()
+    # plt.savefig("temp_line.png")
+
+
+def show_box(im, predictions):
+    """绘制边界框并保存结果"""
+    boxes = predictions[0]['boxes'].cpu().numpy()
+    box_scores = predictions[0]['scores'].cpu().numpy()
+    colors = get_colors()
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for idx, (box, score) in enumerate(zip(boxes, box_scores)):
+        if score < 0:
+            continue
+        x0, y0, x1, y1 = box
+        ax.add_patch(
+            plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=colors[idx % len(colors)], linewidth=1))
+    t_end = time.time()
+    plt.savefig("/home/zqy/docker_lstznkj_1.0/code/ubuntu2204/VisionWeldRobotMK/temp_box.png")
+
+
+def show_predict(im, filtered_pred, t_start):
+    """»æÖÆÆ¥ÅäºóµÄ±ß½ç¿òºÍÏ߶β¢±£´æ½á¹û"""
+    colors = get_colors()
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    pred = filtered_pred[0]
+
+    boxes = pred['boxes'].cpu().numpy()
+    box_scores = pred['scores'].cpu().numpy()
+    lines = pred['line'].cpu().numpy()
+    line_scores = pred['line_score'].cpu().numpy()
+    print("Boxes:", pred['boxes'])
+    print("Lines:", pred['line'])
+    print("Line scores:", pred['line_score'])
+
+    is_all_zeros = np.all(lines == 0.0)
+    if not is_all_zeros:
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+        for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
+            if box_score < 0.0 or line_score < 0.0:
+                continue
+
+            # skip drawing if the segment is empty (no valid segment was found)
+            if line is None or len(line) == 0:
+                continue
+
+            x0, y0, x1, y1 = box
+            a, b = line
+            color = colors[box_idx % len(colors)]  # assign each bounding box a unique color
+            ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
+            ax.scatter(a[1], a[0], c=color, s=10)
+            ax.scatter(b[1], b[0], c=color, s=10)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
+
+        # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        # nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+        # for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
+        #     if box_score < 0.0 or line_score < 0.0:
+        #         continue
+        #
+        #     # skip drawing if the segment is empty (no valid segment was found)
+        #     if line is None or len(line) == 0:
+        #         continue
+        #
+        #     x0, y0, x1, y1 = box
+        #     a, b = line
+        #     color = colors[(idx + box_idx) % len(colors)]  # assign each bounding box a unique color
+        #     ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
+        #     ax.scatter(a[1], a[0], c=color, s=10)
+        #     ax.scatter(b[1], b[0], c=color, s=10)
+        #     ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
+    t_end = time.time()
+    # plt.show()
+    print(f'show_predict used: {t_end - t_start:.2f} seconds')
+    output_path = "/home/zqy/docker_lstznkj_1.0/code/ubuntu2204/VisionWeldRobotMK/temp_result.png"
+    save_plot(output_path)
+    return output_path
+
+
+if __name__ == "__main__":
+    lines = predict(r"C:\Users\m2337\Desktop\p\140502.png")
+    print(f'lines:{lines}')
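
For reference, `process_box` and `box_line_optimized_parallel` both keep a candidate segment for a box only when both endpoints lie inside it, with endpoints indexed (y, x) and boxes as (x0, y0, x1, y1). A condensed sketch of that containment check (the helper name is illustrative, not part of the commit):

    def segment_in_box(line, box):
        # line: [[ya, xa], [yb, xb]] endpoints; box: (x0, y0, x1, y1)
        (ya, xa), (yb, xb) = line
        x0, y0, x1, y1 = box
        return (x0 <= xa <= x1 and x0 <= xb <= x1 and
                y0 <= ya <= y1 and y0 <= yb <= y1)

Coordinate scales differ along the pipeline: the model emits lines on its 128×128 grid, `show_line` and the matcher rescale them by / 128 * 512 to the network input size, and `predict` rescales by / 128 * np.array([2112, 1328]) to what appears to be the capture resolution, so a grid point (64, 64) maps to (1056, 664) in the full image.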