Browse Source

main_lm_0223.py is deployed on port 8001 and uses parallel computation
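
A minimal client sketch for exercising the deployed service (an illustrative assumption: the server from main_lm_0223.py below is running locally on port 8001, and test.jpg / result.png are placeholder paths; the /predict endpoint and its "file" upload parameter are the ones defined in that file):

    import requests

    # upload an image to the line-detection service and save the returned PNG rendering
    with open('test.jpg', 'rb') as f:
        resp = requests.post('http://localhost:8001/predict', files={'file': f})
    resp.raise_for_status()
    with open('result.png', 'wb') as out:
        out.write(resp.content)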

xue50 3 months ago
parent
commit
4b13aea0a2

+ 313 - 0
models/line_detect/boxline.py

@@ -0,0 +1,313 @@
+# from fastapi import FastAPI, File, UploadFile, HTTPException
+# from fastapi.responses import FileResponse
+# from fastapi.staticfiles import StaticFiles
+import os
+import torch
+import numpy as np
+from PIL import Image
+import skimage.io
+import skimage.color
+import skimage.transform
+from torchvision import transforms
+import shutil
+import matplotlib.pyplot as plt
+from models.line_detect.line_net import linenet_resnet50_fpn
+from models.wirenet.postprocess import postprocess
+from rtree import index
+import time
+import multiprocessing as mp
+# from fastapi.middleware.cors import CORSMiddleware
+# Initialize FastAPI
+# app = FastAPI()
+
+# Add CORS middleware
+# app.add_middleware(
+#     CORSMiddleware,
+#     allow_origins=["*"],  # allow all origins
+#     allow_credentials=True,
+#     allow_methods=["*"],
+#     allow_headers=["*"],
+# )
+
+# Set the multiprocessing start method to 'spawn'
+mp.set_start_method('spawn', force=True)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+def load_model(model_path):
+    """Load the model and return the model instance."""
+    model = linenet_resnet50_fpn().to(device)
+    if os.path.exists(model_path):
+        checkpoint = torch.load(model_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        print(f"Loaded model from {model_path}")
+    else:
+        raise FileNotFoundError(f"No saved model found at {model_path}")
+    model.eval()
+    return model
+
+def preprocess_image(image_path):
+    """Preprocess the uploaded image."""
+    img = Image.open(image_path).convert("RGB")
+    transform = transforms.ToTensor()
+    img_tensor = transform(img)
+    resized_img = skimage.transform.resize(
+        img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
+    )
+    return torch.tensor(resized_img).permute(2, 0, 1),img
+
+def save_plot(output_path: str):
+    """Save the figure and close the plot."""
+    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
+    print(f"Saved plot to {output_path}")
+    plt.close()
+
+def get_colors():
+    """Return a list of predefined colors."""
+    return [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+
+def process_box(box, lines, scores):
+    """Find the best matching line segment for each bounding box."""
+    valid_lines = []  # store the valid line segments
+    valid_scores = []  # store the valid scores
+    # print(f'score:{len(scores)}')
+    for i in box:
+        best_line = None
+        max_length = 0.0
+        # iterate over all line segments
+        for j in range(lines.shape[1]):
+            # line_j = lines[0, j].cpu().numpy() / 128 * 512
+            line_j = lines[0, j].cpu().numpy()
+            # check whether the segment lies entirely inside the box
+            if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
+                    line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
+                    line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
+                    line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
+                # compute the segment length
+                length = np.linalg.norm(line_j[0] - line_j[1])
+                # length = scores[j].cpu().numpy()
+                # print(length)
+                if length > max_length:
+                    best_line = line_j
+                    max_length = length
+        # if a valid segment was found, add it to the results
+        if best_line is not None:
+            valid_lines.append(best_line)
+            valid_scores.append(max_length)  # use the segment length as the score
+
+        else:
+            valid_lines.append([[0.0,0.0],[0.0,0.0]])
+            valid_scores.append(0.0)  # no matching segment; use 0.0 as the score
+        # print(f'valid_lines:{valid_lines}')
+        # print(f'valid_scores:{valid_scores}')
+    return valid_lines, valid_scores
+
+def box_line_optimized_parallel(pred):
+    """Match bounding boxes and line segments in parallel."""
+    lines = pred[-1]['wires']['lines']  # shape [1, 2500, 2, 2]
+    scores = pred[-1]['wires']['score'][0]  # assumed shape [2500]
+    boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]]  # all boxes
+    num_processes = min(mp.cpu_count(), len(boxes))  # use the available CPU cores
+    with mp.Pool(processes=num_processes) as pool:
+        results = pool.starmap(
+            process_box,
+            [(box, lines, scores) for box in boxes]
+        )
+    # update the prediction results
+    filtered_pred = []
+    for idx_box, (valid_lines, valid_scores) in enumerate(results):
+        if valid_lines:
+            pred[idx_box]['line'] = torch.tensor(valid_lines)
+            pred[idx_box]['line_score'] = torch.tensor(valid_scores)
+            filtered_pred.append(pred[idx_box])
+    return filtered_pred
+
+
+def predict(image_path):
+
+    start_time = time.time()
+
+    # save the uploaded file
+    # os.makedirs("uploaded_images", exist_ok=True)
+    # image_path = f"{file.filename}"
+    # with open(image_path, "wb") as f:
+    # shutil.copyfileobj(file.file, f)
+
+    # load the model
+    model_path = r'D:\python\PycharmProjects\20250214\weight\linenet_wts\r50fpn_wts_e350\best.pth'
+    model = load_model(model_path)
+
+    # preprocess the image
+    img_tensor,img = preprocess_image(image_path)
+    # W = img.shape[0]
+    im = img_tensor.permute(1, 2, 0).cpu().numpy()
+
+    # model inference
+    with torch.no_grad():
+        predictions = model([img_tensor.to(device)])
+
+    lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+    predictions[-1]['wires']['lines_'] = torch.from_numpy(nlines).float().cuda()
+    predictions[-1]['wires']['score_'] = torch.from_numpy(nscores).float().cuda()
+    print(predictions)
+
+    # match line segments with bounding boxes
+    t_start = time.time()
+    filtered_pred = box_line_optimized_parallel(predictions)
+    t_end = time.time()
+    print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')
+
+    # draw the images
+    output_path_box = show_box(im, predictions, t_start)
+    output_path_line = show_line(im, predictions, t_start)
+    output_path_boxandline = show_predict(im, filtered_pred, t_start)
+
+    # combine the images
+    combined_image_path = "combined_result.png"
+    combine_images(
+        [output_path_boxandline, output_path_box, output_path_line],
+        titles=["Box and Line", "Box", "Line"],
+        output_path=combined_image_path
+    )
+
+    end_time = time.time()
+    print(f'Total time: {end_time - start_time:.2f} seconds')
+
+    # get the line data and add detailed debug information
+    lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * np.array([1328, 2112])
+    print(f"Initial lines shape: {lines.shape}")
+    print(f"Initial lines data type: {lines.dtype}")
+
+    # make sure the data has the expected shape
+    if len(lines.shape) != 3:  # not of shape (N, 2, 2)
+        if len(lines.shape) == 2 and lines.shape[1] == 4:
+            # reshape (N, 4) into (N, 2, 2)
+            lines = lines.reshape(-1, 2, 2)
+        else:
+            print(f"Warning: Unexpected lines shape: {lines.shape}")
+
+    print(f"After reshape - lines shape: {lines.shape}")
+
+    # make sure each point is in [x, y] format
+    formatted_lines = []
+    for line in lines:
+        start_point = np.array([line[0][0], line[0][1]])
+        end_point = np.array([line[1][0], line[1][1]])
+        formatted_lines.append([start_point, end_point])
+
+    formatted_lines = np.array(formatted_lines)
+    print(f"Final formatted_lines shape: {formatted_lines.shape}")
+    print(f"Sample formatted line: {formatted_lines[0] if len(formatted_lines) > 0 else 'No lines'}")
+
+    # make sure a 3-D structure is returned: [lines_array]
+    result = [formatted_lines]
+    print(f"Final result type: {type(result)}")
+    print(f"Final result[0] shape: {result[0].shape}")
+
+    return result
+
+def combine_images(image_paths: list, titles: list, output_path: str):
+    """Combine multiple images into a single figure."""
+    fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5))
+    for ax, img_path, title in zip(axes, image_paths, titles):
+        ax.imshow(plt.imread(img_path))
+        ax.set_title(title)
+        ax.axis("off")
+    plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
+    plt.close()
+
+def show_box(im, predictions, t_start):
+    """Draw the bounding boxes and save the result."""
+    boxes = predictions[0]['boxes'].cpu().numpy()
+    box_scores = predictions[0]['scores'].cpu().numpy()
+    colors = get_colors()
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for idx, (box, score) in enumerate(zip(boxes, box_scores)):
+        if score < 0.7:
+            continue
+        x0, y0, x1, y1 = box
+        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=colors[idx % len(colors)], linewidth=1))
+    t_end = time.time()
+    print(f'show_box used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result_box.png"
+    save_plot(output_path)
+    return output_path
+
+def show_line(im, predictions, t_start):
+    """Draw the line segments and save the result."""
+    lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for (a, b), s in zip(nlines, nscores):
+        if s < 0.9:
+            continue
+        ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
+        ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+    t_end = time.time()
+    print(f'show_line used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result_line.png"
+    save_plot(output_path)
+    return output_path
+
+def show_predict(im, filtered_pred, t_start):
+    """Draw the matched bounding boxes and line segments and save the result."""
+    colors = get_colors()
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for idx, pred in enumerate(filtered_pred):
+        boxes = pred['boxes'].cpu().numpy()
+        box_scores = pred['scores'].cpu().numpy()
+        lines = pred['line'].cpu().numpy()
+        line_scores = pred['line_score'].cpu().numpy()
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+        for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
+            if box_score < 0.7 or line_score < 0.9:
+                continue
+
+            # skip drawing if the segment is empty (no valid segment was found)
+            if line is None or len(line) == 0:
+                continue
+
+            x0, y0, x1, y1 = box
+            a, b = line
+            color = colors[(idx + box_idx) % len(colors)]  # assign a unique color to each bounding box
+            ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
+            ax.scatter(a[1], a[0], c=color, s=10)
+            ax.scatter(b[1], b[0], c=color, s=10)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
+    t_end = time.time()
+    print(f'show_predict used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result.png"
+    save_plot(output_path)
+    return output_path
+
+if __name__ == "__main__":
+    lines = predict(r'C:\Users\m2337\Desktop\p\9.jpg')
+    print(f'lines:{lines}')

BIN
models/line_detect/combined_result.png


+ 320 - 0
models/line_detect/main_lm_0223.py

@@ -0,0 +1,320 @@
+from fastapi import FastAPI, File, UploadFile, HTTPException
+from fastapi.responses import FileResponse
+from fastapi.staticfiles import StaticFiles
+import os
+import torch
+import numpy as np
+from PIL import Image
+import skimage.io
+import skimage.color
+import skimage.transform
+from torchvision import transforms
+import shutil
+import matplotlib.pyplot as plt
+from models.line_detect.line_net import linenet_resnet50_fpn
+from models.line_detect.postprocess import postprocess
+from rtree import index
+import time
+import multiprocessing as mp
+from fastapi.middleware.cors import CORSMiddleware
+# Initialize FastAPI
+app = FastAPI()
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # allow all origins
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Set the multiprocessing start method to 'spawn'
+mp.set_start_method('spawn', force=True)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+def load_model(model_path):
+    """Load the model and return the model instance."""
+    model = linenet_resnet50_fpn().to(device)
+    if os.path.exists(model_path):
+        checkpoint = torch.load(model_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        print(f"Loaded model from {model_path}")
+    else:
+        raise FileNotFoundError(f"No saved model found at {model_path}")
+    model.eval()
+    return model
+
+def preprocess_image(image_path):
+    """Preprocess the uploaded image."""
+    img = Image.open(image_path).convert("RGB")
+    transform = transforms.ToTensor()
+    img_tensor = transform(img)
+    resized_img = skimage.transform.resize(
+        img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
+    )
+    return torch.tensor(resized_img).permute(2, 0, 1)
+
+def save_plot(output_path: str):
+    """Save the figure and close the plot."""
+    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
+    print(f"Saved plot to {output_path}")
+    plt.close()
+
+def get_colors():
+    """Return a list of predefined colors."""
+    return [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+
+# def process_box(box, lines, scores):
+#     """Find the best matching line segment for each bounding box."""
+#     valid_lines = []  # store the valid line segments
+#     valid_scores = []  # store the valid scores
+#     for i in box:
+#         best_line = None
+#         max_length = 0.0
+#         # iterate over all line segments
+#         for j in range(lines.shape[1]):
+#             line_j = lines[0, j].cpu().numpy() / 128 * 512
+#             # check whether the segment lies entirely inside the box
+#             if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
+#                     line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
+#                     line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
+#                     line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
+#                 # compute the segment length
+#                 length = np.linalg.norm(line_j[0] - line_j[1])
+#                 if length > max_length:
+#                     best_line = line_j
+#                     max_length = length
+#         # if a valid segment was found, add it to the results
+#         if best_line is not None:
+#             valid_lines.append(best_line)
+#             valid_scores.append(max_length)  # use the segment length as the score
+#     return valid_lines, valid_scores
+def process_box(box, lines, scores):
+    """Find the best matching line segment for each bounding box."""
+    valid_lines = []  # store the valid line segments
+    valid_scores = []  # store the valid scores
+    # print(f'score:{len(scores)}')
+    for i in box:
+        best_line = None
+        max_length = 0.0
+        # iterate over all line segments
+        for j in range(lines.shape[1]):
+            line_j = lines[0, j].cpu().numpy() / 128 * 512
+            # check whether the segment lies entirely inside the box
+            if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
+                    line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
+                    line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
+                    line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
+                # length = np.linalg.norm(line_j[0] - line_j[1])
+                length = scores[j].item()
+                # print(length)
+                if length > max_length:
+                    best_line = line_j
+                    max_length = length
+        if best_line is not None:
+            valid_lines.append(best_line)
+            valid_scores.append(max_length)
+
+        else:
+            valid_lines.append([[0.0,0.0],[0.0,0.0]])
+            valid_scores.append(0.0)
+        # print(f'valid_lines:{valid_lines}')
+        # print(f'valid_scores:{valid_scores}')
+    return valid_lines, valid_scores
+
+def box_line_optimized_parallel(pred):
+    """Match bounding boxes and line segments in parallel."""
+    lines = pred[-1]['wires']['lines']  # shape [1, 2500, 2, 2]
+    scores = pred[-1]['wires']['score'][0]  # assumed shape [2500]
+    boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]]  # all boxes
+    num_processes = min(mp.cpu_count(), len(boxes))  # use the available CPU cores
+    with mp.Pool(processes=num_processes) as pool:
+        results = pool.starmap(
+            process_box,
+            [(box, lines, scores) for box in boxes]
+        )
+    # update the prediction results
+    filtered_pred = []
+    for idx_box, (valid_lines, valid_scores) in enumerate(results):
+        if valid_lines:
+            pred[idx_box]['line'] = torch.tensor(valid_lines)
+            pred[idx_box]['line_score'] = torch.tensor(valid_scores)
+            filtered_pred.append(pred[idx_box])
+    return filtered_pred
+
+@app.get("/")
+def read_root():
+    """Serve the static index page."""
+    return FileResponse("static/index.html")
+
+@app.post("/predict")
+@app.post("/predict/")
+async def predict(file: UploadFile = File(...)):
+    try:
+        start_time = time.time()
+
+        # save the uploaded file
+        os.makedirs("uploaded_images", exist_ok=True)
+        image_path = f"uploaded_images/{file.filename}"
+        with open(image_path, "wb") as f:
+            shutil.copyfileobj(file.file, f)
+
+        # load the model
+        model_path = "/data/share/rlq/weights/linenet_wts/resnet50_best_e280.pth"
+        model = load_model(model_path)
+
+        # preprocess the image
+        img_tensor = preprocess_image(image_path)
+        im = img_tensor.permute(1, 2, 0).cpu().numpy()
+
+        # model inference
+        with torch.no_grad():
+            predictions = model([img_tensor.to(device)])
+
+        # match line segments with bounding boxes
+        t_start = time.time()
+        filtered_pred = box_line_optimized_parallel(predictions)
+        t_end = time.time()
+        print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')
+
+        # draw the images
+        output_path_box = show_box(im, predictions, t_start)
+        output_path_line = show_line(im, predictions, t_start)
+        output_path_boxandline = show_predict(im, filtered_pred, t_start)
+
+        # # combine the images
+        # combined_image_path = "combined_result.png"
+        # combine_images(
+        #     [output_path_boxandline],
+        #     # [output_path_boxandline, output_path_box, output_path_line],
+        #     titles=["Box and Line"],
+        #     output_path=combined_image_path
+        # )
+
+        end_time = time.time()
+        print(f'Total time: {end_time - start_time:.2f} seconds')
+
+        # return the rendered result image
+        return FileResponse(output_path_boxandline, media_type="image/png", filename="result.png")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+def combine_images(image_paths: list, titles: list, output_path: str):
+    """Combine multiple images into a single figure."""
+    fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5))
+    for ax, img_path, title in zip(axes, image_paths, titles):
+        ax.imshow(plt.imread(img_path))
+        ax.set_title(title)
+        ax.axis("off")
+    plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
+    plt.close()
+
+def show_box(im, predictions, t_start):
+    """Draw the bounding boxes and save the result."""
+    boxes = predictions[0]['boxes'].cpu().numpy()
+    box_scores = predictions[0]['scores'].cpu().numpy()
+    colors = get_colors()
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for idx, (box, score) in enumerate(zip(boxes, box_scores)):
+        if score < 0.7:
+            continue
+        x0, y0, x1, y1 = box
+        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=colors[idx % len(colors)], linewidth=1))
+    t_end = time.time()
+    print(f'show_box used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result_box.png"
+    save_plot(output_path)
+    return output_path
+
+def show_line(im, predictions, t_start):
+    """Draw the line segments and save the result."""
+    lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for (a, b), s in zip(nlines, nscores):
+        if s < 0.9:
+            continue
+        ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
+        ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+    t_end = time.time()
+    print(f'show_line used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result_line.png"
+    save_plot(output_path)
+    return output_path
+
+def show_predict(im, pred, t_start):
+    """Draw the matched bounding boxes and line segments and save the result."""
+    col = get_colors()
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+
+    boxes = pred[0]['boxes'].cpu().numpy()
+    box_scores = pred[0]['scores'].cpu().numpy()
+    lines = pred[0]['line'].cpu().numpy()
+    line_scores = pred[0]['line_score'].cpu().numpy()
+    idx = 0
+
+    tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
+
+    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
+        x0, y0, x1, y1 = box
+        if np.array_equal(line, tmp):
+            continue
+        a, b = line
+        if box_score >= 0.7 or line_score >= 0.9:
+            ax.add_patch(
+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+            ax.scatter(a[1], a[0], c='#871F78', s=10)
+            ax.scatter(b[1], b[0], c='#871F78', s=10)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
+            idx = idx + 1
+    # for idx, pred in enumerate(filtered_pred):
+    #     boxes = pred['boxes'].cpu().numpy()
+    #     box_scores = pred['scores'].cpu().numpy()
+    #     lines = pred['line'].cpu().numpy()
+    #     line_scores = pred['line_score'].cpu().numpy()
+    #     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    #     nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+    #     for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
+    #         if box_score > 0.7 and line_score > 0.9:
+    #             x0, y0, x1, y1 = box
+    #             a, b = line
+    #             color = colors[(idx + box_idx) % len(colors)]  # assign a unique color to each bounding box
+    #             ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
+    #             ax.scatter(a[1], a[0], c=color, s=10)
+    #             ax.scatter(b[1], b[0], c=color, s=10)
+    #             ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
+    t_end = time.time()
+    print(f'show_predict used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result.png"
+    save_plot(output_path)
+    return output_path
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8001)

+ 37 - 2
models/line_detect/postprocess.py

@@ -157,10 +157,15 @@ def show_predict(imgs, pred, t_start):
     ax.imshow(np.array(im))
     idx = 0
 
+    tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
+
     for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
         x0, y0, x1, y1 = box
+        # skip boxes that contain no line segment
+        if np.array_equal(line, tmp):
+            continue
         a, b = line
-        if box_score > 0.7 and line_score > 0.9:
+        if box_score >= 0.7 or line_score >= 0.9:
             ax.add_patch(
                 plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
             ax.scatter(a[1], a[0], c='#871F78', s=10)
@@ -176,12 +181,17 @@ def show_predict(imgs, pred, t_start):
 # None of the functions below do one-to-one matching between boxes and lines
 # Only draw lines, with a threshold
 def show_line(imgs, pred, t_start):
-
     im = imgs.permute(1, 2, 0)
     line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
     # print(pred[-1]['wires']['score'])
     line_score = pred[-1]['wires']['score'].cpu().numpy()[0]
 
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    line, line_score = postprocess(line, line_score, diag * 0.01, 0, False)
+
+    # count = np.sum(line_score > 0.9)
+    # print(f'draw line number:{count}')
+
     # visualize the prediction results
     fig, ax = plt.subplots(figsize=(10, 10))
     ax.imshow(np.array(im))
@@ -199,6 +209,31 @@ def show_line(imgs, pred, t_start):
     plt.show()
 
 
+# Optimized version of show_line
+def show_line_optimized(imgs, pred, t_start):
+    im = imgs.permute(1, 2, 0).cpu().numpy()
+    line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
+
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(line_data, line_scores, diag * 0.01, 0, False)
+
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+
+    for i, t in enumerate([0.9]):
+        for (a, b), s in zip(nlines, nscores):
+            if s < t:
+                continue
+            ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+
+    t_end = time.time()
+    print(f'show_line_optimized used:{t_end - t_start}')
+
+    plt.show()
+
+
 # Only draw boxes, with a threshold
 def show_box(imgs, pred, t_start):
     col = [

+ 1 - 1
models/line_detect/predict.py

@@ -128,5 +128,5 @@ if __name__ == '__main__':
     pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
     # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png'  # workpiece image
     # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe image
-    img_path = r'C:\Users\m2337\Desktop\49.jpg'
+    img_path = r'C:\Users\m2337\Desktop\p\24.jpg'
     predict(pt_path, model, img_path)

+ 49 - 12
models/line_detect/predict2.py

@@ -2,7 +2,8 @@ import time
 
 import skimage
 
-from models.line_detect.postprocess import show_predict, show_line, show_box, show_box_or_line, show_box_and_line
+from models.line_detect.postprocess import show_predict, show_box, show_box_or_line, show_box_and_line, \
+    show_line_optimized, show_line
 import os
 
 import torch
@@ -34,11 +35,21 @@ def load_best_model(model, save_path, device):
     return model
 
 
-def box_line_(pred):
+def box_line_(imgs, pred):
+    im = imgs.permute(1, 2, 0).cpu().numpy()
+    line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
+
+    # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    # line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
     for idx, box_ in enumerate(pred[0:-1]):
         box = box_['boxes']  # this is a tensor
         line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
         score = pred[-1]['wires']['score'][idx]
+        #
+        # diag = (512 ** 2 + 512 ** 2) ** 0.5
+        # lines, scores = postprocess(line, score, diag * 0.01, 0, False)
+
         line_ = []
         score_ = []
 
@@ -142,28 +153,54 @@ def predict(pt_path, model, img):
         predictions = model([img_.to(device)])
         # print(predictions)
 
-    show_line(img_, predictions, t_start)   # 只画线
+    # lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    # scores = predictions[-1]['wires']['score'][0].cpu().numpy() / 128 * 512
+    # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    # nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+    # print(len(nlines))
+
+    # arr = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    # unique_subarrays = set()
+    #
+    # for i in range(arr.shape[0]):
+    #     for j in range(arr.shape[1]):
+    #         subarray = arr[i, j]
+    #         # make sure subarray has the expected shape (2,)
+    #         if subarray.shape != (2,):
+    #             raise ValueError(f"Unexpected shape of subarray at index [{i}, {j}]: {subarray.shape}, expected (2,)")
+    #
+    #         subarray_tuple = tuple(subarray.tolist())
+    #         unique_subarrays.add(subarray_tuple)
+    #
+    # # count the number of unique sub-arrays
+    # num_unique_subarrays = len(unique_subarrays)
+    # print(f"There are {num_unique_subarrays} distinct [2, 2] sub-arrays")
+
+    # show_line_optimized(img_, predictions, t_start)   # only draw lines
+    show_line(img_, predictions, t_start)
     # show_box(img_, predictions, t_start)   # only draw boxes
     # show_box_or_line(img_, predictions, show_line=True, show_box=True)   # flags decide what to draw
     # show_box_and_line(img_, predictions, show_line=True, show_box=True)  # draw both together in a 1x2 figure
 
-    # t_start = time.time()
-    # # pred = box_line_optimized(predictions)
-    # pred = box_line_(predictions)
-    # t_end = time.time()
-    # print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
+    t_start = time.time()
+    # pred = box_line_optimized(predictions)
+    pred = box_line_(img_, predictions)
+    t_end = time.time()
+    print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
 
-    # show_predict(img_, pred, t_start)
+    show_predict(img_, pred, t_start)
 
 
 if __name__ == '__main__':
     t_start = time.time()
     print(f'start to predict:{t_start}')
     model = linenet_resnet50_fpn().to(device)
-    pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
-    # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png'  # 工件图
+    # pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
+    # pt_path = r'D:\python\PycharmProjects\linenet_wts\r50fpn_wts_e350\best.pth'
+    pt_path = r'D:\python\PycharmProjects\20250214\weight\linenet_wts\r50fpn_wts_e350\best.pth'
+    # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-43-13_SaveImage.png'  # workpiece image
     # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe image
-    img_path = r'C:\Users\m2337\Desktop\9.jpg'
+    img_path = r'C:\Users\m2337\Desktop\p\2025-01-03-09-34-32_SaveImage_adjust_brightness_contrast.jpg'
     predict(pt_path, model, img_path)
     t_end = time.time()
     print(f'predict used:{t_end - t_start}')

+ 1 - 1
models/line_detect/predict3.py

@@ -215,7 +215,7 @@ if __name__ == '__main__':
     pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
     # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png'  # workpiece image
     # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe image
-    img_path = r'C:\Users\m2337\Desktop\49.jpg'
+    img_path = r'C:\Users\m2337\Desktop\p\49.jpg'
     predict(pt_path, model, img_path)
     t_end = time.time()
     print(f'predict used:{t_end - t_start}')

+ 274 - 0
models/line_detect/predict_0221.py

@@ -0,0 +1,274 @@
+# zjf's version uses the segment length as the score; here the maximum confidence is used as the score
+
+# from fastapi import FastAPI, File, UploadFile, HTTPException
+# from fastapi.responses import FileResponse
+# from fastapi.staticfiles import StaticFiles
+import os
+import torch
+import numpy as np
+from PIL import Image
+import skimage.io
+import skimage.color
+import skimage.transform
+from torchvision import transforms
+import shutil
+import matplotlib.pyplot as plt
+from models.line_detect.line_net import linenet_resnet50_fpn
+from models.line_detect.postprocess import postprocess
+from rtree import index
+import time
+import multiprocessing as mp
+# from fastapi.middleware.cors import CORSMiddleware
+# Initialize FastAPI
+# app = FastAPI()
+
+# Add CORS middleware
+# app.add_middleware(
+#     CORSMiddleware,
+#     allow_origins=["*"],  # allow all origins
+#     allow_credentials=True,
+#     allow_methods=["*"],
+#     allow_headers=["*"],
+# )
+
+# Set the multiprocessing start method to 'spawn'
+mp.set_start_method('spawn', force=True)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+def load_model(model_path):
+    """Load the model and return the model instance."""
+    model = linenet_resnet50_fpn().to(device)
+    if os.path.exists(model_path):
+        checkpoint = torch.load(model_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        print(f"Loaded model from {model_path}")
+    else:
+        raise FileNotFoundError(f"No saved model found at {model_path}")
+    model.eval()
+    return model
+
+def preprocess_image(image_path):
+    """Preprocess the uploaded image."""
+    img = Image.open(image_path).convert("RGB")
+    transform = transforms.ToTensor()
+    img_tensor = transform(img)
+    resized_img = skimage.transform.resize(
+        img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
+    )
+    return torch.tensor(resized_img).permute(2, 0, 1)
+
+def save_plot(output_path: str):
+    """Save the figure and close the plot."""
+    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
+    print(f"Saved plot to {output_path}")
+    plt.close()
+
+def get_colors():
+    """Return a list of predefined colors."""
+    return [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+
+def process_box(box, lines, scores):
+    """Find the best matching line segment for each bounding box."""
+    valid_lines = []  # store the valid line segments
+    valid_scores = []  # store the valid scores
+    # print(f'score:{len(scores)}')
+    for i in box:
+        best_line = None
+        max_length = 0.0
+        # iterate over all line segments
+        for j in range(lines.shape[1]):
+            line_j = lines[0, j].cpu().numpy() / 128 * 512
+            # check whether the segment lies entirely inside the box
+            if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
+                    line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
+                    line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
+                    line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
+                # length = np.linalg.norm(line_j[0] - line_j[1])
+                length = scores[j].item()
+                # print(length)
+                if length > max_length:
+                    best_line = line_j
+                    max_length = length
+        if best_line is not None:
+            valid_lines.append(best_line)
+            valid_scores.append(max_length)
+
+        else:
+            valid_lines.append([[0.0,0.0],[0.0,0.0]])
+            valid_scores.append(0.0)  # no matching segment; use 0.0 as the score
+        # print(f'valid_lines:{valid_lines}')
+        # print(f'valid_scores:{valid_scores}')
+    return valid_lines, valid_scores
+
+def box_line_optimized_parallel(pred):
+    """Match bounding boxes and line segments in parallel."""
+    lines = pred[-1]['wires']['lines']  # shape [1, 2500, 2, 2]
+    scores = pred[-1]['wires']['score'][0]  # assumed shape [2500]
+    boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]]  # all boxes
+    num_processes = min(mp.cpu_count(), len(boxes))  # use the available CPU cores
+    with mp.Pool(processes=num_processes) as pool:
+        results = pool.starmap(
+            process_box,
+            [(box, lines, scores) for box in boxes]
+        )
+    # update the prediction results
+    filtered_pred = []
+    for idx_box, (valid_lines, valid_scores) in enumerate(results):
+        if valid_lines:
+            pred[idx_box]['line'] = torch.tensor(valid_lines)
+            pred[idx_box]['line_score'] = torch.tensor(valid_scores)
+            filtered_pred.append(pred[idx_box])
+    return filtered_pred
+
+
+def predict(image_path):
+
+    start_time = time.time()
+
+    # save the uploaded file
+    # os.makedirs("uploaded_images", exist_ok=True)
+    # image_path = f"{file.filename}"
+    # with open(image_path, "wb") as f:
+    #     shutil.copyfileobj(file.file, f)
+
+    # load the model
+    model_path = "/data/lm/resnet50_best_e280.pth"
+    model = load_model(model_path)
+
+    # preprocess the image
+    img_tensor = preprocess_image(image_path)
+    im = img_tensor.permute(1, 2, 0).cpu().numpy()
+
+    # model inference
+    with torch.no_grad():
+        predictions = model([img_tensor.to(device)])
+
+    # match line segments with bounding boxes
+    t_start = time.time()
+    filtered_pred = box_line_optimized_parallel(predictions)
+    t_end = time.time()
+    print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')
+
+    # draw the images
+    output_path_box = show_box(im, predictions, t_start)
+    output_path_line = show_line(im, predictions, t_start)
+    output_path_boxandline = show_predict(im, filtered_pred, t_start)
+
+    # combine the images
+    combined_image_path = "combined_result.png"
+    combine_images(
+        [output_path_boxandline, output_path_box, output_path_line],
+        titles=["Box and Line", "Box", "Line"],
+        output_path=combined_image_path
+    )
+
+    end_time = time.time()
+    print(f'Total time: {end_time - start_time:.2f} seconds')
+
+    #     # return the combined image
+    #     return FileResponse(combined_image_path, media_type="image/png", filename="combined_result.png")
+    # except Exception as e:
+    #     raise HTTPException(status_code=500, detail=str(e))
+
+def combine_images(image_paths: list, titles: list, output_path: str):
+    """Combine multiple images into a single figure."""
+    fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5))
+    for ax, img_path, title in zip(axes, image_paths, titles):
+        ax.imshow(plt.imread(img_path))
+        ax.set_title(title)
+        ax.axis("off")
+    plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
+    plt.close()
+
+def show_box(im, predictions, t_start):
+    """Draw the bounding boxes and save the result."""
+    boxes = predictions[0]['boxes'].cpu().numpy()
+    box_scores = predictions[0]['scores'].cpu().numpy()
+    colors = get_colors()
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for idx, (box, score) in enumerate(zip(boxes, box_scores)):
+        if score < 0.7:
+            continue
+        x0, y0, x1, y1 = box
+        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=colors[idx % len(colors)], linewidth=1))
+    t_end = time.time()
+    print(f'show_box used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result_box.png"
+    save_plot(output_path)
+    return output_path
+
+def show_line(im, predictions, t_start):
+    """Draw the line segments and save the result."""
+    lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for (a, b), s in zip(nlines, nscores):
+        if s < 0.9:
+            continue
+        ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
+        ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+    t_end = time.time()
+    print(f'show_line used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result_line.png"
+    save_plot(output_path)
+    return output_path
+
+def show_predict(im, filtered_pred, t_start):
+    """Draw the matched bounding boxes and line segments and save the result."""
+    colors = get_colors()
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for idx, pred in enumerate(filtered_pred):
+        boxes = pred['boxes'].cpu().numpy()
+        box_scores = pred['scores'].cpu().numpy()
+        lines = pred['line'].cpu().numpy()
+        line_scores = pred['line_score'].cpu().numpy()
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+        for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
+            if box_score < 0.7 or line_score < 0.9:
+                continue
+
+            # skip drawing if the segment is empty (no valid segment was found)
+            if line is None or len(line) == 0:
+                continue
+
+            x0, y0, x1, y1 = box
+            a, b = line
+            color = colors[(idx + box_idx) % len(colors)]  # assign a unique color to each bounding box
+            ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
+            ax.scatter(a[1], a[0], c=color, s=10)
+            ax.scatter(b[1], b[0], c=color, s=10)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
+    t_end = time.time()
+    print(f'show_predict used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result.png"
+    save_plot(output_path)
+    return output_path
+
+if __name__ == "__main__":
+    predict('/data/lm/wirenet_1000/images/train/00031643_1.png')

+ 116 - 0
models/line_detect/predict_lm.py

@@ -0,0 +1,116 @@
+import time
+
+import skimage
+
+from models.line_detect.postprocess import show_predict, show_box, show_box_or_line, show_box_and_line, \
+    show_line_optimized, show_line
+import os
+
+import torch
+from PIL import Image
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+import numpy as np
+from models.line_detect.line_net import linenet_resnet50_fpn
+from torchvision import transforms
+
+# from models.wirenet.postprocess import postprocess
+from models.wirenet.postprocess import postprocess
+from rtree import index
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def load_best_model(model, save_path, device):
+    if os.path.exists(save_path):
+        checkpoint = torch.load(save_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        # if optimizer is not None:
+        #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        epoch = checkpoint['epoch']
+        loss = checkpoint['loss']
+        print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+    else:
+        print(f"No saved model found at {save_path}")
+    return model
+
+
+def box_line_(imgs, pred):
+    im = imgs.permute(1, 2, 0).cpu().numpy()
+    line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
+
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
+    for idx, box_ in enumerate(pred[0:-1]):
+        box = box_['boxes']  # this is a tensor
+        line_ = []
+        score_ = []
+
+        for i in box:
+            score_max = 0.0
+            tmp = [[0.0, 0.0], [0.0, 0.0]]
+
+            for j in range(len(line)):
+                if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
+                        line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
+                        line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
+                        line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+
+                    if score[j] > score_max:
+                        tmp = line[j]
+                        score_max = score[j]
+            line_.append(tmp)
+            score_.append(score_max)
+        processed_list = torch.tensor(line_)
+        pred[idx]['line'] = processed_list
+
+        processed_s_list = torch.tensor(score_)
+        pred[idx]['line_score'] = processed_s_list
+    return pred
+
+
+def predict(pt_path, model, img):
+    model = load_best_model(model, pt_path, device)
+
+    model.eval()
+
+    if isinstance(img, str):
+        img = Image.open(img).convert("RGB")
+
+    transform = transforms.ToTensor()
+    img_tensor = transform(img)  # [3, 512, 512]
+
+    # img_ = img_tensor
+
+    # resize the image to 512x512
+    t_start = time.time()
+    im = img_tensor.permute(1, 2, 0)  # [512, 512, 3]
+    im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
+    img_ = torch.tensor(im_resized).permute(2, 0, 1)
+    t_end = time.time()
+    print(f'switch img used:{t_end - t_start}')
+
+    with torch.no_grad():
+        predictions = model([img_.to(device)])
+        # print(predictions)
+
+    t_start = time.time()
+    pred = box_line_(img_, predictions)
+    t_end = time.time()
+    print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
+
+    show_predict(img_, pred, t_start)
+
+
+if __name__ == '__main__':
+    t_start = time.time()
+    print(f'start to predict:{t_start}')
+    model = linenet_resnet50_fpn().to(device)
+    pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
+    # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-43-13_SaveImage.png'  # workpiece image
+    # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe image
+    img_path = r'C:\Users\m2337\Desktop\p\9.jpg'
+    predict(pt_path, model, img_path)
+    t_end = time.time()
+    print(f'predict used:{t_end - t_start}')

+ 3 - 4
models/line_detect/predict_zjf.py → models/line_detect/predict_zjf(1).py

@@ -1,3 +1,4 @@
+
 import time
 import skimage
 from models.line_detect.postprocess import show_predict
@@ -15,7 +16,7 @@ import multiprocessing as mp
 mp.set_start_method('spawn', force=True)
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
+print(f'device:{device}')
 def load_best_model(model, save_path, device):
     if os.path.exists(save_path):
         checkpoint = torch.load(save_path, map_location=device)
@@ -128,9 +129,7 @@ if __name__ == '__main__':
     print(f'Start to predict: {t_start}')
     model = linenet_resnet50_fpn().to(device)
     pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
-    # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png'  # 工件图
-    # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe图
-    img_path = r'C:\Users\m2337\Desktop\9.jpg'
+    img_path = r'C:\Users\m2337\Desktop\p\49.jpg'
     predict(pt_path, model, img_path)
     t_end = time.time()
     print(f'Total prediction time: {t_end - t_start:.4f} seconds')

+ 2 - 2
models/line_detect/train.yaml

@@ -1,7 +1,7 @@
 io:
   logdir: logs/
-  datadir: D:\python\PycharmProjects\data
-#  datadir: I:\datasets\wirenet_1000
+#  datadir: D:\python\PycharmProjects\data
+  datadir: D:\python\PycharmProjects\data_20250223\lcnn_20250223
   resume_from:
   num_workers: 8
   tensorboard_port: 6000

+ 276 - 0
models/line_detect/zjf_0221.py

@@ -0,0 +1,276 @@
+# 0221: use the segment length as the line score
+
+# from fastapi import FastAPI, File, UploadFile, HTTPException
+# from fastapi.responses import FileResponse
+# from fastapi.staticfiles import StaticFiles
+import os
+import torch
+import numpy as np
+from PIL import Image
+import skimage.io
+import skimage.color
+import skimage.transform
+from torchvision import transforms
+import shutil
+import matplotlib.pyplot as plt
+from models.line_detect.line_net import linenet_resnet50_fpn
+from models.line_detect.postprocess import postprocess
+from rtree import index
+import time
+import multiprocessing as mp
+# from fastapi.middleware.cors import CORSMiddleware
+# Initialize FastAPI
+# app = FastAPI()
+
+# Add CORS middleware
+# app.add_middleware(
+#     CORSMiddleware,
+#     allow_origins=["*"],  # allow all origins
+#     allow_credentials=True,
+#     allow_methods=["*"],
+#     allow_headers=["*"],
+# )
+
+# Set the multiprocessing start method to 'spawn'
+mp.set_start_method('spawn', force=True)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+def load_model(model_path):
+    """Load the model and return the model instance."""
+    model = linenet_resnet50_fpn().to(device)
+    if os.path.exists(model_path):
+        checkpoint = torch.load(model_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        print(f"Loaded model from {model_path}")
+    else:
+        raise FileNotFoundError(f"No saved model found at {model_path}")
+    model.eval()
+    return model
+
+def preprocess_image(image_path):
+    """Preprocess the uploaded image."""
+    img = Image.open(image_path).convert("RGB")
+    transform = transforms.ToTensor()
+    img_tensor = transform(img)
+    resized_img = skimage.transform.resize(
+        img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
+    )
+    return torch.tensor(resized_img).permute(2, 0, 1)
+
+def save_plot(output_path: str):
+    """Save the figure and close the plot."""
+    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
+    print(f"Saved plot to {output_path}")
+    plt.close()
+
+def get_colors():
+    """Return a list of predefined colors."""
+    return [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+
+def process_box(box, lines, scores):
+    """Find the best matching line segment for each bounding box."""
+    valid_lines = []  # store the valid line segments
+    valid_scores = []  # store the valid scores
+    print(f'score:{len(scores)}')
+    for i in box:
+        best_line = None
+        max_length = 0.0
+        # iterate over all line segments
+        for j in range(lines.shape[1]):
+            line_j = lines[0, j].cpu().numpy() / 128 * 512
+            # check whether the segment lies entirely inside the box
+            if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
+                    line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
+                    line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
+                    line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
+                # compute the segment length
+                length = np.linalg.norm(line_j[0] - line_j[1])
+                # length = scores[j].cpu().numpy()
+                print(length)
+                if length > max_length:
+                    best_line = line_j
+                    max_length = length
+        # if a valid segment was found, add it to the results
+        if best_line is not None:
+            valid_lines.append(best_line)
+            valid_scores.append(max_length)  # use the segment length as the score
+
+        else:
+            valid_lines.append([[0.0,0.0],[0.0,0.0]])
+            valid_scores.append(0.0)  # no matching segment; use 0.0 as the score
+        print(f'valid_lines:{valid_lines}')
+        print(f'valid_scores:{valid_scores}')
+    return valid_lines, valid_scores
+
+def box_line_optimized_parallel(pred):
+    """Match bounding boxes and line segments in parallel."""
+    lines = pred[-1]['wires']['lines']  # shape [1, 2500, 2, 2]
+    scores = pred[-1]['wires']['score'][0]  # assumed shape [2500]
+    boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]]  # all boxes
+    num_processes = min(mp.cpu_count(), len(boxes))  # use the available CPU cores
+    with mp.Pool(processes=num_processes) as pool:
+        results = pool.starmap(
+            process_box,
+            [(box, lines, scores) for box in boxes]
+        )
+    # update the prediction results
+    filtered_pred = []
+    for idx_box, (valid_lines, valid_scores) in enumerate(results):
+        if valid_lines:
+            pred[idx_box]['line'] = torch.tensor(valid_lines)
+            pred[idx_box]['line_score'] = torch.tensor(valid_scores)
+            filtered_pred.append(pred[idx_box])
+    return filtered_pred
+
+
+def predict(image_path):
+
+    start_time = time.time()
+
+    # save the uploaded file
+    # os.makedirs("uploaded_images", exist_ok=True)
+    # image_path = f"{file.filename}"
+    # with open(image_path, "wb") as f:
+    #     shutil.copyfileobj(file.file, f)
+
+    # load the model
+    model_path = "/home/dieu/PycharmProjects/MultiVisionModels/logs/pth/resnet50_best_e100.pth"
+    model = load_model(model_path)
+
+    # preprocess the image
+    img_tensor = preprocess_image(image_path)
+    im = img_tensor.permute(1, 2, 0).cpu().numpy()
+
+    # model inference
+    with torch.no_grad():
+        predictions = model([img_tensor.to(device)])
+
+    # match line segments with bounding boxes
+    t_start = time.time()
+    filtered_pred = box_line_optimized_parallel(predictions)
+    t_end = time.time()
+    print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')
+
+    # draw the images
+    output_path_box = show_box(im, predictions, t_start)
+    output_path_line = show_line(im, predictions, t_start)
+    output_path_boxandline = show_predict(im, filtered_pred, t_start)
+
+    # combine the images
+    combined_image_path = "combined_result.png"
+    combine_images(
+        [output_path_boxandline, output_path_box, output_path_line],
+        titles=["Box and Line", "Box", "Line"],
+        output_path=combined_image_path
+    )
+
+    end_time = time.time()
+    print(f'Total time: {end_time - start_time:.2f} seconds')
+
+    #     # return the combined image
+    #     return FileResponse(combined_image_path, media_type="image/png", filename="combined_result.png")
+    # except Exception as e:
+    #     raise HTTPException(status_code=500, detail=str(e))
+
+def combine_images(image_paths: list, titles: list, output_path: str):
+    """Combine multiple images into a single figure."""
+    fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5))
+    for ax, img_path, title in zip(axes, image_paths, titles):
+        ax.imshow(plt.imread(img_path))
+        ax.set_title(title)
+        ax.axis("off")
+    plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
+    plt.close()
+
+def show_box(im, predictions, t_start):
+    """Draw the bounding boxes and save the result."""
+    boxes = predictions[0]['boxes'].cpu().numpy()
+    box_scores = predictions[0]['scores'].cpu().numpy()
+    colors = get_colors()
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for idx, (box, score) in enumerate(zip(boxes, box_scores)):
+        if score < 0.7:
+            continue
+        x0, y0, x1, y1 = box
+        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=colors[idx % len(colors)], linewidth=1))
+    t_end = time.time()
+    print(f'show_box used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result_box.png"
+    save_plot(output_path)
+    return output_path
+
+def show_line(im, predictions, t_start):
+    """Draw the line segments and save the result."""
+    lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for (a, b), s in zip(nlines, nscores):
+        if s < 0.9:
+            continue
+        ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
+        ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+    t_end = time.time()
+    print(f'show_line used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result_line.png"
+    save_plot(output_path)
+    return output_path
+
+def show_predict(im, filtered_pred, t_start):
+    """Draw the matched bounding boxes and line segments and save the result."""
+    colors = get_colors()
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for idx, pred in enumerate(filtered_pred):
+        boxes = pred['boxes'].cpu().numpy()
+        box_scores = pred['scores'].cpu().numpy()
+        lines = pred['line'].cpu().numpy()
+        line_scores = pred['line_score'].cpu().numpy()
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+        for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
+            if box_score < 0.7 or line_score < 0.9:
+                continue
+
+            # skip drawing if the segment is empty (no valid segment was found)
+            if line is None or len(line) == 0:
+                continue
+
+            x0, y0, x1, y1 = box
+            a, b = line
+            color = colors[(idx + box_idx) % len(colors)]  # assign a unique color to each bounding box
+            ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
+            ax.scatter(a[1], a[0], c=color, s=10)
+            ax.scatter(b[1], b[0], c=color, s=10)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
+    t_end = time.time()
+    print(f'show_predict used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result.png"
+    save_plot(output_path)
+    return output_path
+
+if __name__ == "__main__":
+    predict('/home/dieu/PycharmProjects/MultiVisionModels/00030043_0.png')

+ 81 - 0
split_dataset.py

@@ -0,0 +1,81 @@
+import os
+import random
+import shutil
+from sklearn.model_selection import train_test_split
+
+# Define paths
+data_dir = 'D:\python\PycharmProjects\data_20250223\lcnn新T型板十字板增强后(复件)'  # replace with your data folder path
+output_dir = 'D:\python\PycharmProjects\data_20250223\lcnn_20250223'  # replace with the folder where the output should be saved
+
+# Create the output directory structure
+images_train_dir = os.path.join(output_dir, 'images', 'train')
+images_val_dir = os.path.join(output_dir, 'images', 'val')
+labels_train_dir = os.path.join(output_dir, 'labels', 'train')
+labels_val_dir = os.path.join(output_dir, 'labels', 'val')
+
+os.makedirs(images_train_dir, exist_ok=True)
+os.makedirs(images_val_dir, exist_ok=True)
+os.makedirs(labels_train_dir, exist_ok=True)
+os.makedirs(labels_val_dir, exist_ok=True)
+
+# Collect all image file names and the corresponding json files
+image_files = [f for f in os.listdir(data_dir) if f.endswith('.jpg')]
+json_files = {f.replace('.json', ''): f for f in os.listdir(data_dir) if f.endswith('.json')}
+
+# Extract the image names (without extension) to match the json files
+image_names = [os.path.splitext(f)[0] for f in image_files]
+
+# Split the dataset with a 9:1 ratio
+train_names, val_names = train_test_split(image_names, test_size=0.1, random_state=42)
+
+# Copy the files into the corresponding directories
+for name in train_names:
+    image_file = name + '.jpg'
+    json_file = json_files[name]
+
+    # copy the image file
+    shutil.copy(os.path.join(data_dir, image_file), os.path.join(images_train_dir, image_file))
+
+    # copy the json file
+    shutil.copy(os.path.join(data_dir, json_file), os.path.join(labels_train_dir, json_file))
+
+for name in val_names:
+    image_file = name + '.jpg'
+    json_file = json_files[name]
+
+    # copy the image file
+    shutil.copy(os.path.join(data_dir, image_file), os.path.join(images_val_dir, image_file))
+
+    # copy the json file
+    shutil.copy(os.path.join(data_dir, json_file), os.path.join(labels_val_dir, json_file))
+
+# Example invocation
+if __name__ == "__main__":
+    # Make sure the output directory and its sub-directories exist
+    os.makedirs(images_train_dir, exist_ok=True)
+    os.makedirs(images_val_dir, exist_ok=True)
+    os.makedirs(labels_train_dir, exist_ok=True)
+    os.makedirs(labels_val_dir, exist_ok=True)
+
+    # Perform the split and copy the files
+    train_names, val_names = train_test_split(image_names, test_size=0.1, random_state=42)
+
+    for name in train_names:
+        image_file = name + '.jpg'
+        json_file = json_files[name]
+
+        # copy the image file
+        shutil.copy(os.path.join(data_dir, image_file), os.path.join(images_train_dir, image_file))
+
+        # copy the json file
+        shutil.copy(os.path.join(data_dir, json_file), os.path.join(labels_train_dir, json_file))
+
+    for name in val_names:
+        image_file = name + '.jpg'
+        json_file = json_files[name]
+
+        # copy the image file
+        shutil.copy(os.path.join(data_dir, image_file), os.path.join(images_val_dir, image_file))
+
+        # copy the json file
+        shutil.copy(os.path.join(data_dir, json_file), os.path.join(labels_val_dir, json_file))