
first commit

xue50 2 months ago
parent
commit
16a1c3655c

+ 225 - 69
aaa.py

@@ -1,77 +1,233 @@
+# import multiprocessing
+#
+#
+# def worker(num):
+#     """线程调用的函数"""
+#     print(f"Worker: {num}")
+#     return
+#
+#
+# if __name__ == '__main__':
+#     jobs = []
+#     for i in range(5):
+#         p = multiprocessing.Process(target=worker, args=(i,))
+#         jobs.append(p)
+#         p.start()
+#
+# import sys
+# print(sys.version)
+
+from typing import List
+from fastapi import FastAPI, File, UploadFile, HTTPException
+from fastapi.responses import FileResponse
+import os
 import torch
-from torchvision.utils import draw_bounding_boxes
+import numpy as np
+from PIL import Image
+import skimage.io
+import skimage.color
+import skimage.transform
 from torchvision import transforms
+import shutil
 import matplotlib.pyplot as plt
-import numpy as np
-
-
-def c(score):
-    # Function that returns a color based on the score; example only, adjust as needed
-    return (1, 0, 0) if score > 0.9 else (0, 1, 0)
-
-
-def postprocess(lines, scores, diag_threshold, min_score, remove_overlaps):
-    # Placeholder postprocessing function for filtering line segments
-    nlines = []
-    nscores = []
-    for line, score in zip(lines, scores):
-        if score >= min_score:
-            nlines.append(line)
-            nscores.append(score)
-    return np.array(nlines), np.array(nscores)
-
-
-def show_line(img, pred, epoch, writer):
-    im = img.permute(1, 2, 0).cpu().numpy()
-
-    # Draw bounding boxes
-    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
-                                      colors="yellow", width=1).permute(1, 2, 0).cpu().numpy()
-
-    H = pred[-1]['wires']
-    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
-    scores = H["score"][0].cpu().numpy()
-
-    print(f"Lines before deduplication: {len(lines)}")
-
-    # Remove duplicate line segments
-    for i in range(1, len(lines)):
-        if (lines[i] == lines[0]).all():
-            lines = lines[:i]
-            scores = scores[:i]
-            break
-
-    print(f"Lines after deduplication: {len(lines)}")
-
-    # Postprocess line segments to remove overlapping ones
-    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
-    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
-
-    print(f"Lines after postprocessing: {len(nlines)}")
-
-    # Create a new figure and draw the line segments and bounding boxes
-    fig, ax = plt.subplots(figsize=(boxed_image.shape[1] / 100, boxed_image.shape[0] / 100))
-    ax.imshow(boxed_image)
-    ax.set_axis_off()
-    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
-    plt.margins(0, 0)
-    plt.gca().xaxis.set_major_locator(plt.NullLocator())
-    plt.gca().yaxis.set_major_locator(plt.NullLocator())
+from models.line_detect.line_net import linenet_resnet50_fpn
+from models.line_detect.postprocess import postprocess
+import time
+import multiprocessing as mp
+from fastapi.middleware.cors import CORSMiddleware
+# Initialize FastAPI
+app = FastAPI()
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # 允许所有源
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Set the multiprocessing start method to 'spawn'
+mp.set_start_method('spawn', force=True)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+def load_model(model_path):
+    """加载模型并返回模型实例"""
+    model = linenet_resnet50_fpn().to(device)
+    if os.path.exists(model_path):
+        checkpoint = torch.load(model_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        print(f"Loaded model from {model_path}")
+    else:
+        raise FileNotFoundError(f"No saved model found at {model_path}")
+    model.eval()
+    return model
+
+def preprocess_image(image_path):
+    """预处理上传的图片"""
+    img = Image.open(image_path).convert("RGB")
+    transform = transforms.ToTensor()
+    img_tensor = transform(img)
+    resized_img = skimage.transform.resize(
+        img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
+    )
+    return torch.tensor(resized_img).permute(2, 0, 1)
+
+def save_plot(output_path: str):
+    """保存图像并关闭绘图"""
+    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
+    print(f"Saved plot to {output_path}")
+    plt.close()
 
-    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+def get_colors():
+    """返回一组预定义的颜色列表"""
+    return [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+
+def show_line(im, lines, line_scores, diag, t_start):
+    """绘制线段并保存结果"""
+    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
     for (a, b), s in zip(nlines, nscores):
-        if s < 0.85:  # 调整阈值以筛选显示的线段
+        if s < 0.9:
+            continue
+        ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
+        ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+
+    t_end = time.time()
+    print(f'show_line used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result_line.png"
+    save_plot(output_path)
+    return output_path
+
+def show_box(im, boxes, box_scores, colors, t_start):
+    """绘制边界框并保存结果"""
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for idx, (box, score) in enumerate(zip(boxes, box_scores)):
+        if score < 0.7:
             continue
-        plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
-        plt.scatter(a[1], a[0], **PLTOPTS)
-        plt.scatter(b[1], b[0], **PLTOPTS)
-
-    plt.tight_layout()
-    fig.canvas.draw()
-    image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
-        fig.canvas.get_width_height()[::-1] + (3,))
+        x0, y0, x1, y1 = box
+        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=colors[idx % len(colors)], linewidth=1))
+
+    t_end = time.time()
+    print(f'show_box used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result_box.png"
+    save_plot(output_path)
+    return output_path
+
+def show_predict(im, boxes, box_scores, lines, line_scores, diag, colors, t_start):
+    """绘制边界框和线段并保存结果"""
+    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for box, line, box_score, line_score, color in zip(boxes, nlines, box_scores, nscores, colors):
+        if box_score > 0.7 and line_score > 0.9:
+            x0, y0, x1, y1 = box
+            a, b = line
+            ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
+            ax.scatter(a[1], a[0], c='#871F78', s=10)
+            ax.scatter(b[1], b[0], c='#871F78', s=10)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
+
+    t_end = time.time()
+    print(f'show_predict used: {t_end - t_start:.2f} seconds')
+    output_path = "temp_result.png"
+    save_plot(output_path)
+    return output_path
+
+@app.get("/")
+def read_root():
+    """返回前端页面"""
+    return FileResponse("static/index.html")
+
+@app.post("/predict")
+@app.post("/predict/")
+async def predict(file: UploadFile = File(...)):
+    try:
+        start_time = time.time()
+
+        # Save the uploaded file
+        os.makedirs("uploaded_images", exist_ok=True)
+        image_path = f"uploaded_images/{file.filename}"
+        with open(image_path, "wb") as f:
+            shutil.copyfileobj(file.file, f)
+
+        # Load the model
+        # model_path = "/home/dieu/PycharmProjects/MultiVisionModels/logs/pth/resnet50_best_e100.pth"
+        model_path = r'D:\python\PycharmProjects\20250214\weight\merged_model_weights.pth'
+        model = load_model(model_path)
+
+        # Preprocess the image
+        img_tensor = preprocess_image(image_path)
+        im = img_tensor.permute(1, 2, 0).cpu().numpy()
+
+        # Run model inference
+        with torch.no_grad():
+            predictions = model([img_tensor.to(device)])
+
+        # Extract the predictions
+        boxes = predictions[0]['boxes'].cpu().numpy()
+        box_scores = predictions[0]['scores'].cpu().numpy()
+        lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+        line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        colors = get_colors()
+
+        # Draw the result images
+        t_start = time.time()
+        output_path_box = show_box(im, boxes, box_scores, colors, t_start)
+        output_path_line = show_line(im, lines, line_scores, diag, t_start)
+        output_path_boxandline = show_predict(im, boxes, box_scores, lines, line_scores, diag, colors, t_start)
+
+        # Combine the images
+        combined_image_path = "combined_result.png"
+        combine_images(
+            [output_path_boxandline, output_path_box, output_path_line],
+            titles=["Box and Line", "Box", "Line"],
+            output_path=combined_image_path
+        )
+
+        end_time = time.time()
+        print(f'Total time: {end_time - start_time:.2f} seconds')
+
+        # Return the combined image
+        return FileResponse(combined_image_path, media_type="image/png", filename="combined_result.png")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+def combine_images(image_paths: List[str], titles: List[str], output_path: str):
+    """将多个图像合并为一张图片"""
+    fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5))
+    for ax, img_path, title in zip(axes, image_paths, titles):
+        ax.imshow(plt.imread(img_path))
+        ax.set_title(title)
+        ax.axis("off")
+    plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
     plt.close()
-    img2 = transforms.ToTensor()(image_from_plot)
 
-    writer.add_image("output_with_boxes_and_lines", img2, epoch)
-    print("Image with boxes and lines added to TensorBoard.")
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=808)

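Note: aaa.py exposes the pipeline as a FastAPI service; POST an image to /predict and the server responds with the combined PNG produced by combine_images. Below is a minimal client sketch (not part of the commit), assuming the service runs locally; port 808 matches the uvicorn.run call above, and the example file names are placeholders.

import requests

# Send an image to the /predict endpoint and save the returned combined PNG.
url = "http://127.0.0.1:808/predict"
with open("example.png", "rb") as f:
    resp = requests.post(url, files={"file": ("example.png", f, "image/png")})
resp.raise_for_status()
with open("combined_result_client.png", "wb") as out:
    out.write(resp.content)
print(f"Saved {len(resp.content)} bytes")
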
BIN
models/line_detect/20250226_170723.png


BIN
models/line_detect/20250226_171735.png


BIN
models/line_detect/20250227_094144/box_line.png


BIN
models/line_detect/20250227_100324/box.png


+ 299 - 0
models/line_detect/lm_0223.py

@@ -0,0 +1,299 @@
+import os
+import torch
+import numpy as np
+from PIL import Image
+import skimage.io
+import skimage.color
+import skimage.transform
+from torchvision import transforms
+import shutil
+import matplotlib.pyplot as plt
+from models.line_detect.line_net import linenet_resnet50_fpn
+from models.line_detect.postprocess import postprocess
+from rtree import index
+import time
+import multiprocessing as mp
+
+mp.set_start_method('spawn', force=True)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def load_model(model_path):
+    model = linenet_resnet50_fpn().to(device)
+    if os.path.exists(model_path):
+        checkpoint = torch.load(model_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        print(f"Loaded model from {model_path}")
+    else:
+        raise FileNotFoundError(f"No saved model found at {model_path}")
+    model.eval()
+    return model
+
+
+def preprocess_image(image_path):
+    img = Image.open(image_path).convert("RGB")
+    transform = transforms.ToTensor()
+    img_tensor = transform(img)
+    resized_img = skimage.transform.resize(
+        img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
+    )
+    return torch.tensor(resized_img).permute(2, 0, 1)
+
+
+def save_plot(output_path: str):
+    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
+    print(f"Saved plot to {output_path}")
+    plt.close()
+
+
+def get_colors():
+    """返回一组预定义的颜色列表"""
+    return [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+
+
+def process_box(box, lines, scores):
+    """处理单个边界框,找到最佳匹配的线段"""
+    valid_lines = []  # 存储有效的线段
+    valid_scores = []  # 存储有效的分数
+    # print(f'score:{len(scores)}')
+    for i in box:
+        best_line = None
+        max_length = 0.0
+        # Iterate over all line segments
+        for j in range(lines.shape[1]):
+            line_j = lines[0, j].cpu().numpy() / 128 * 512
+            # Check whether the line segment lies entirely inside the box
+            if (all(line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
+                    line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and line_j[0][0] <= i[3] and line_j[1][0] <= i[3])):
+                # length = np.linalg.norm(line_j[0] - line_j[1])
+                length = scores[j].item()
+                # print(length)
+                if length > max_length:
+                    best_line = line_j
+                    max_length = length
+        if best_line is not None:
+            valid_lines.append(best_line)
+            valid_scores.append(max_length)
+
+        else:
+            valid_lines.append([[0.0, 0.0], [0.0, 0.0]])
+            valid_scores.append(0.0)
+    return valid_lines, valid_scores
+
+
+def box_line_optimized_parallel(pred):
+    """并行处理边界框和线段的匹配"""
+    lines = pred[-1]['wires']['lines']  # 形状为[1, 2500, 2, 2]
+    scores = pred[-1]['wires']['score'][0]  # 假设形状为[2500]
+    boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]]  # 所有box
+    # num_processes = min(mp.cpu_count(), len(boxes))  # 使用可用的核心数
+    # with mp.Pool(processes=num_processes) as pool:
+    #     results = pool.starmap(
+    #         process_box,
+    #         [(box, lines, scores) for box in boxes]
+    #     )
+    results = [process_box(box, lines, scores) for box in boxes]
+    # Update the predictions
+    filtered_pred = []
+    for idx_box, (valid_lines, valid_scores) in enumerate(results):
+        if valid_lines:
+            pred[idx_box]['line'] = torch.tensor(valid_lines)
+            pred[idx_box]['line_score'] = torch.tensor(valid_scores)
+            filtered_pred.append(pred[idx_box])
+    return filtered_pred
+
+
+def predict(image_path):
+    start_time = time.time()
+
+    model_path = r'D:\python\PycharmProjects\20250214\weight\linenet_wts\r50fpn_wts_e350\best.pth'
+    model = load_model(model_path)
+
+    img_tensor = preprocess_image(image_path)
+    im = img_tensor.permute(1, 2, 0).cpu().numpy()
+
+    with torch.no_grad():
+        predictions = model([img_tensor.to(device)])
+
+    t_start = time.time()
+    filtered_pred = box_line_optimized_parallel(predictions)
+    t_end = time.time()
+    print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')
+
+    output_path_box = show_box(im, predictions, t_start)
+    output_path_line = show_line(im, predictions, t_start)
+    show_predict(im, filtered_pred, t_start)
+
+    # combined_image_path = "combined_result.png"
+    # combine_images(
+    #     [output_path_boxandline, output_path_box, output_path_line],
+    #     titles=["Box and Line", "Box", "Line"],
+    #     output_path=combined_image_path
+    # )
+
+    end_time = time.time()
+    print(f'Total time: {end_time - start_time:.2f} seconds')
+
+
+def combine_images(image_paths: list, titles: list, output_path: str):
+    """将多个图像合并为一张图片"""
+    fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5))
+    for ax, img_path, title in zip(axes, image_paths, titles):
+        ax.imshow(plt.imread(img_path))
+        ax.set_title(title)
+        ax.axis("off")
+    plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
+    plt.close()
+
+
+def show_box(im, predictions, t_start):
+    """绘制边界框并保存结果"""
+    boxes = predictions[0]['boxes'].cpu().numpy()
+    box_scores = predictions[0]['scores'].cpu().numpy()
+    colors = get_colors()
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for idx, (box, score) in enumerate(zip(boxes, box_scores)):
+        if score < 0.7:
+            continue
+        x0, y0, x1, y1 = box
+        ax.add_patch(
+            plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=colors[idx % len(colors)], linewidth=1))
+    t_end = time.time()
+    print(f'show_box used: {t_end - t_start:.2f} seconds')
+    plt.show()
+    output_path = "temp_result_box.png"
+    save_plot(output_path)
+    return output_path
+
+
+def show_line(im, predictions, t_start):
+    """绘制线段并保存结果"""
+    lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(im)
+    for (a, b), s in zip(nlines, nscores):
+        if s < 0.9:
+            continue
+        ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
+        ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+    t_end = time.time()
+    print(f'show_line used: {t_end - t_start:.2f} seconds')
+    plt.show()
+    output_path = "temp_result_line.png"
+    save_plot(output_path)
+    return output_path
+
+
+# def show_predict(im, filtered_pred, t_start):
+#     colors = get_colors()
+#     fig, ax = plt.subplots(figsize=(10, 10))
+#     ax.imshow(im)
+#     for idx, pred in enumerate(filtered_pred):
+#         boxes = pred['boxes'].cpu().numpy()
+#         box_scores = pred['scores'].cpu().numpy()
+#         lines = pred['line'].cpu().numpy()
+#         line_scores = pred['line_score'].cpu().numpy()
+#         diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+#         nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
+#         for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
+#             if box_score < 0.7 or line_score < 0.9:
+#                 continue
+#
+#             if line is None or len(line) == 0:
+#                 continue
+#
+#             x0, y0, x1, y1 = box
+#             a, b = line
+#             color = colors[(idx + box_idx) % len(colors)]  # assign each bounding box a unique color
+#             ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
+#             ax.scatter(a[1], a[0], c=color, s=10)
+#             ax.scatter(b[1], b[0], c=color, s=10)
+#             ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
+#     t_end = time.time()
+#     print(f'show_predict used: {t_end - t_start:.2f} seconds')
+#     plt.show()
+#     output_path = "temp_result.png"
+#     save_plot(output_path)
+#     return output_path
+def show_predict(imgs, pred, t_start):
+    col = [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+    print(imgs.shape)
+    # im = imgs.permute(1, 2, 0)  # reshape to [512, 512, 3]
+    boxes = pred[0]['boxes'].cpu().numpy()
+    box_scores = pred[0]['scores'].cpu().numpy()
+    lines = pred[0]['line'].cpu().numpy()
+    line_scores = pred[0]['line_score'].cpu().numpy()
+
+    # Visualize the predictions
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(imgs))
+    idx = 0
+
+    tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
+
+    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
+        x0, y0, x1, y1 = box
+        # Skip boxes that contain no line
+        if np.array_equal(line, tmp):
+            continue
+        a, b = line
+        if box_score >= 0.7 or line_score >= 0.9:
+            ax.add_patch(
+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+            ax.scatter(a[1], a[0], c='#871F78', s=10)
+            ax.scatter(b[1], b[0], c='#871F78', s=10)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
+            idx = idx + 1
+    t_end = time.time()
+    print(f'predict used:{t_end - t_start}')
+
+    plt.show()
+
+
+if __name__ == "__main__":
+    predict(r'C:\Users\m2337\Desktop\p\22.png')

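Note: process_box keeps, for every box, the highest-scoring line whose two endpoints both fall inside the box; endpoints are stored in (y, x) order while boxes are (x0, y0, x1, y1), hence the swapped indices in the comparison. The sketch below is an illustrative vectorized equivalent of that inner loop (not part of the commit), assuming the lines have already been rescaled to image coordinates as an (N, 2, 2) array.

import numpy as np

def best_line_for_box(box, lines, scores):
    # box: (x0, y0, x1, y1); lines: (N, 2, 2) endpoints in (y, x) order; scores: (N,)
    x0, y0, x1, y1 = box
    xs, ys = lines[:, :, 1], lines[:, :, 0]
    inside = ((xs >= x0) & (xs <= x1) & (ys >= y0) & (ys <= y1)).all(axis=1)
    if not inside.any():
        return np.zeros((2, 2)), 0.0  # same sentinel as process_box
    best = np.argmax(np.where(inside, scores, -np.inf))
    return lines[best], float(scores[best])
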
+ 3 - 18
models/line_detect/main_lm_0223.py

@@ -139,7 +139,7 @@ def process_box(box, lines, scores):
         else:
             valid_lines.append([[0.0,0.0],[0.0,0.0]])
             valid_scores.append(0.0)
-        # print(f'valid_lines:{valid_lines}')
+        print(f'valid_lines:{valid_lines}')
         # print(f'valid_scores:{valid_scores}')
     return valid_lines, valid_scores
 
@@ -293,22 +293,7 @@ def show_predict(im, pred, t_start):
             ax.scatter(b[1], b[0], c='#871F78', s=10)
             ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
             idx = idx + 1
-    # for idx, pred in enumerate(filtered_pred):
-    #     boxes = pred['boxes'].cpu().numpy()
-    #     box_scores = pred['scores'].cpu().numpy()
-    #     lines = pred['line'].cpu().numpy()
-    #     line_scores = pred['line_score'].cpu().numpy()
-    #     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
-    #     nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
-    #     for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
-    #         if box_score > 0.7 and line_score > 0.9:
-    #             x0, y0, x1, y1 = box
-    #             a, b = line
-    #             color = colors[(idx + box_idx) % len(colors)]  # assign each bounding box a unique color
-    #             ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
-    #             ax.scatter(a[1], a[0], c=color, s=10)
-    #             ax.scatter(b[1], b[0], c=color, s=10)
-    #             ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
+
     t_end = time.time()
     print(f'show_predict used: {t_end - t_start:.2f} seconds')
     output_path = "temp_result.png"
@@ -317,4 +302,4 @@ def show_predict(im, pred, t_start):
 
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8001)
+    uvicorn.run(app, host="192.168.50.100", port=808, log_level="debug")

+ 179 - 26
models/line_detect/postprocess.py

@@ -1,3 +1,4 @@
+import os
 import time
 
 import torch
@@ -7,6 +8,8 @@ from torchvision import transforms
 
 from models.wirenet.postprocess import postprocess
 
+from datetime import datetime
+
 
 def box_line(pred):
     '''
@@ -188,7 +191,8 @@ def show_line(imgs, pred, t_start):
 
     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
     line, line_score = postprocess(line, line_score, diag * 0.01, 0, False)
-
+    print(f'lines num:{len(line)}')
+    #
     # count = np.sum(line_score > 0.9)
     # print(f'draw line number:{count}')
 
@@ -197,8 +201,8 @@ def show_line(imgs, pred, t_start):
     ax.imshow(np.array(im))
 
     for idx, (a, b) in enumerate(line):
-        if line_score[idx] < 0.9:
-            continue
+        # if line_score[idx] < 0.7:
+        #     continue
         ax.scatter(a[1], a[0], c='#871F78', s=2)
         ax.scatter(b[1], b[0], c='#871F78', s=2)
         ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
@@ -281,7 +285,55 @@ def show_box(imgs, pred, t_start):
 
 
 # Merge show_line and show_box; arguments decide whether to show boxes or lines; if neither is shown, the original image is output
-def show_box_or_line(imgs, pred, show_line=False, show_box=False):
+# def show_box_or_line(imgs, pred, show_line=False, show_box=False):
+#     col = [
+#         '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+#         '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+#         '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+#         '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+#         '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+#         '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+#         '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+#         '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+#         '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+#         '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+#         '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+#         '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+#         '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+#         '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+#         '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+#         '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+#         '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+#         '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+#         '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+#         '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+#     ]
+#     # print(len(col))
+#     im = imgs.permute(1, 2, 0)
+#     boxes = pred[0]['boxes'].cpu().numpy()
+#     line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+#
+#     # Visualize the predictions
+#     fig, ax = plt.subplots(figsize=(10, 10))
+#     ax.imshow(np.array(im))
+#
+#     if show_box:
+#         for idx, box in enumerate(boxes):
+#             x0, y0, x1, y1 = box
+#             ax.add_patch(
+#                 plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+#
+#     if show_line:
+#         for idx, (a, b) in enumerate(line):
+#             ax.scatter(a[1], a[0], c='#871F78', s=2)
+#             ax.scatter(b[1], b[0], c='#871F78', s=2)
+#             ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+#
+#     plt.show()
+
+
+# Merge show_line and show_box; arguments decide whether to show boxes or lines, drawn together in one figure
+def show_box_and_line(imgs, pred, show_line=False, show_box=False):
     col = [
         '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
         '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
@@ -310,27 +362,43 @@ def show_box_or_line(imgs, pred, show_line=False, show_box=False):
     line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
 
     # Visualize the predictions
-    fig, ax = plt.subplots(figsize=(10, 10))
-    ax.imshow(np.array(im))
+    fig, axs = plt.subplots(1, 2, figsize=(10, 10))
 
     if show_box:
+        axs[0].imshow(np.array(im))
         for idx, box in enumerate(boxes):
             x0, y0, x1, y1 = box
-            ax.add_patch(
+            axs[0].add_patch(
                 plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+        axs[0].set_title('Boxes')
 
     if show_line:
+        axs[1].imshow(np.array(im))
         for idx, (a, b) in enumerate(line):
-            ax.scatter(a[1], a[0], c='#871F78', s=2)
-            ax.scatter(b[1], b[0], c='#871F78', s=2)
-            ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+            axs[1].scatter(a[1], a[0], c='#871F78', s=2)
+            axs[1].scatter(b[1], b[0], c='#871F78', s=2)
+            axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+        axs[1].set_title('Lines')
 
+    # Adjust spacing between subplots so titles and labels do not overlap
+    plt.tight_layout()
     plt.show()
 
 
-# 将show_line与show_box合并,传入参数确定显示框还是线  一起画
-def show_box_and_line(imgs, pred, show_line=False, show_box=False):
-    col = [
+def set_thresholds(threshold):
+    if isinstance(threshold, list):
+        if len(threshold) != 2:
+            raise ValueError("Threshold list must contain exactly two elements.")
+        a, b = threshold
+    elif isinstance(threshold, (int, float)):
+        a = b = threshold
+    else:
+        raise TypeError("Threshold must be either a list of two numbers or a single number.")
+
+    return a, b
+
+def color():
+    return  [
         '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
         '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
         '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
@@ -352,30 +420,115 @@ def show_box_and_line(imgs, pred, show_line=False, show_box=False):
         '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
         '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
     ]
-    # print(len(col))
+
+def show_all(imgs, pred, threshold, save_path, show):
+    col = color()
+    box_th, line_th = set_thresholds(threshold)
     im = imgs.permute(1, 2, 0)
+
     boxes = pred[0]['boxes'].cpu().numpy()
+    box_scores = pred[0]['scores'].cpu().numpy()
     line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_score = pred[-1]['wires']['score'].cpu().numpy()[0]
+
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    line, line_score = postprocess(line, line_score, diag * 0.01, 0, False)
+
+    fig, axs = plt.subplots(1, 3, figsize=(10, 10))
+
+    axs[0].imshow(np.array(im))
+    for idx, box in enumerate(boxes):
+        if box_scores[idx] < box_th:
+            continue
+        x0, y0, x1, y1 = box
+        axs[0].add_patch(
+            plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+    axs[0].set_title('Boxes')
+
+    axs[1].imshow(np.array(im))
+    for idx, (a, b) in enumerate(line):
+        if line_score[idx] < line_th:
+            continue
+        axs[1].scatter(a[1], a[0], c='#871F78', s=2)
+        axs[1].scatter(b[1], b[0], c='#871F78', s=2)
+        axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+    axs[1].set_title('Lines')
+
+    axs[2].imshow(np.array(im))
+    lines = pred[0]['line'].cpu().numpy()
+    line_scores = pred[0]['line_score'].cpu().numpy()
+    idx = 0
+    tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
+    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
+        x0, y0, x1, y1 = box
+        # Skip boxes that contain no line
+        if np.array_equal(line, tmp):
+            continue
+        a, b = line
+        if box_score >= 0.7 or line_score >= 0.9:
+            axs[2].add_patch(
+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+            axs[2].scatter(a[1], a[0], c='#871F78', s=10)
+            axs[2].scatter(b[1], b[0], c='#871F78', s=10)
+            axs[2].plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
+            idx = idx + 1
+    axs[2].set_title('Boxes and Lines')
+
+    if save_path:
+        save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box_line.png')
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+        plt.savefig(save_path)
+        print(f"Saved result image to {save_path}")
+
+    if show:
+        # Adjust spacing between subplots so titles and labels do not overlap
+        plt.tight_layout()
+        plt.show()
+
+
+def show_box_or_line(imgs, pred, threshold, save_path = None, show_line=False, show_box=False):
+    col = color()
+    box_th, line_th = set_thresholds(threshold)
+    im = imgs.permute(1, 2, 0)
+
+    boxes = pred[0]['boxes'].cpu().numpy()
+    box_scores = pred[0]['scores'].cpu().numpy()
+    line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_score = pred[-1]['wires']['score'].cpu().numpy()[0]
 
     # Visualize the predictions
-    fig, axs = plt.subplots(1, 2, figsize=(10, 10))
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(im))
 
     if show_box:
-        axs[0].imshow(np.array(im))
         for idx, box in enumerate(boxes):
+            if box_scores[idx] < box_th:
+                continue
             x0, y0, x1, y1 = box
-            axs[0].add_patch(
+            ax.add_patch(
                 plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
-        axs[0].set_title('Boxes')
+        if save_path:
+            save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box.png')
+            os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+            plt.savefig(save_path)
+            print(f"Saved result image to {save_path}")
+
 
     if show_line:
-        axs[1].imshow(np.array(im))
         for idx, (a, b) in enumerate(line):
-            axs[1].scatter(a[1], a[0], c='#871F78', s=2)
-            axs[1].scatter(b[1], b[0], c='#871F78', s=2)
-            axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
-        axs[1].set_title('Lines')
+            if line_score[idx] < line_th:
+                continue
+            ax.scatter(a[1], a[0], c='#871F78', s=2)
+            ax.scatter(b[1], b[0], c='#871F78', s=2)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+        if save_path:
+            save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'line.png')
+            os.makedirs(os.path.dirname(save_path), exist_ok=True)
 
-    # Adjust spacing between subplots so titles and labels do not overlap
-    plt.tight_layout()
-    plt.show()
+            plt.savefig(save_path)
+            print(f"Saved result image to {save_path}")
+
+
+    plt.show()

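Note: the new set_thresholds helper accepts either a single number (applied to both boxes and lines) or a two-element list [box_threshold, line_threshold]; anything else raises. A quick usage sketch with made-up values:

box_th, line_th = set_thresholds(0.5)         # -> (0.5, 0.5)
box_th, line_th = set_thresholds([0.7, 0.9])  # -> (0.7, 0.9)
# set_thresholds([0.1, 0.2, 0.3]) -> ValueError; set_thresholds("0.5") -> TypeError
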
+ 18 - 35
models/line_detect/predict2.py

@@ -27,28 +27,28 @@ def load_best_model(model, save_path, device):
         model.load_state_dict(checkpoint['model_state_dict'])
         # if optimizer is not None:
         #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-        epoch = checkpoint['epoch']
-        loss = checkpoint['loss']
-        print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+        # epoch = checkpoint['epoch']
+        # loss = checkpoint['loss']
+        # print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
     else:
         print(f"No saved model found at {save_path}")
     return model
 
 
-def box_line_(imgs, pred):
+def box_line_(imgs, pred, length=False):    # match by confidence score by default
     im = imgs.permute(1, 2, 0).cpu().numpy()
     line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
     line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
 
-    # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
-    # line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
     for idx, box_ in enumerate(pred[0:-1]):
         box = box_['boxes']  # this is a tensor
-        line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
-        score = pred[-1]['wires']['score'][idx]
-        #
+        # line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
+        # score = pred[-1]['wires']['score'][idx]
+
         # diag = (512 ** 2 + 512 ** 2) ** 0.5
-        # lines, scores = postprocess(line, score, diag * 0.01, 0, False)
+        # line, score = postprocess(line, score, diag * 0.01, 0, False)
 
         line_ = []
         score_ = []
@@ -63,6 +63,12 @@ def box_line_(imgs, pred):
                         line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
                         line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
 
+                    # # Compute the segment length
+                    # length = np.linalg.norm(line[j][0] - line[j][1])
+                    # if length > score_max:
+                    #     tmp = line[j]
+                    #     score_max = score[j]
+
                     if score[j] > score_max:
                         tmp = line[j]
                         score_max = score[j]
@@ -153,29 +159,6 @@ def predict(pt_path, model, img):
         predictions = model([img_.to(device)])
         # print(predictions)
 
-    # lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
-    # scores = predictions[-1]['wires']['score'][0].cpu().numpy() / 128 * 512
-    # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
-    # nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
-    # print(len(nlines))
-
-    # arr = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
-    # unique_subarrays = set()
-    #
-    # for i in range(arr.shape[0]):
-    #     for j in range(arr.shape[1]):
-    #         subarray = arr[i, j]
-    #         # 确保 subarray 是一个二维数组
-    #         if subarray.shape != (2,):
-    #             raise ValueError(f"Unexpected shape of subarray at index [{i}, {j}]: {subarray.shape}, expected (2,)")
-    #
-    #         subarray_tuple = tuple(subarray.tolist())
-    #         unique_subarrays.add(subarray_tuple)
-    #
-    # # Count the number of unique subarrays
-    # num_unique_subarrays = len(unique_subarrays)
-    # print(f"共有 {num_unique_subarrays} 个不同的 [2, 2] 子数组")
-
     # show_line_optimized(img_, predictions, t_start)   # draw lines only
     show_line(img_, predictions, t_start)
     # show_box(img_, predictions, t_start)   # draw boxes only
@@ -197,10 +180,10 @@ if __name__ == '__main__':
     model = linenet_resnet50_fpn().to(device)
     # pt_path = r'D:\python\PycharmProjects\20250214\weight\resnet50_best_e100.pth'
     # pt_path = r'D:\python\PycharmProjects\linenet_wts\r50fpn_wts_e350\best.pth'
-    pt_path = r'D:\python\PycharmProjects\20250214\weight\linenet_wts\r50fpn_wts_e350\best.pth'
+    pt_path = r'D:\python\PycharmProjects\20250214\weight\merged_model_weights.pth'
     # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-43-13_SaveImage.png'  # workpiece image
     # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe image
-    img_path = r'C:\Users\m2337\Desktop\p\2025-01-03-09-34-32_SaveImage_adjust_brightness_contrast.jpg'
+    img_path = r'C:\Users\m2337\Desktop\p\112941.jpg'
     predict(pt_path, model, img_path)
     t_end = time.time()
     print(f'predict used:{t_end - t_start}')

+ 371 - 0
models/line_detect/predict_0226.py

@@ -0,0 +1,371 @@
+import time
+
+import skimage
+import skimage.transform
+
+# from models.line_detect.postprocess import show_predict, show_box, show_box_or_line, show_box_and_line, \
+#     show_line_optimized, show_line, show_all
+import os
+
+import torch
+from PIL import Image
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+import numpy as np
+from models.line_detect.line_net import linenet_resnet50_fpn
+from torchvision import transforms
+
+# from models.wirenet.postprocess import postprocess
+from models.wirenet.postprocess import postprocess
+from rtree import index
+
+from datetime import datetime
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def load_best_model(model, save_path, device):
+    if os.path.exists(save_path):
+        checkpoint = torch.load(save_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        # if optimizer is not None:
+        #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        epoch = checkpoint['epoch']
+        loss = checkpoint['loss']
+        print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+    else:
+        print(f"No saved model found at {save_path}")
+    return model
+
+
+def box_line_(imgs, pred, length=False):  # match by confidence score by default
+    im = imgs.permute(1, 2, 0).cpu().numpy()
+    line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
+
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
+    for idx, box_ in enumerate(pred[0:-1]):
+        box = box_['boxes']  # this is a tensor
+        # line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
+        # score = pred[-1]['wires']['score'][idx]
+
+        # diag = (512 ** 2 + 512 ** 2) ** 0.5
+        # line, score = postprocess(line, score, diag * 0.01, 0, False)
+
+        line_ = []
+        score_ = []
+
+        for i in box:
+            score_max = 0.0
+            tmp = [[0.0, 0.0], [0.0, 0.0]]
+
+            for j in range(len(line)):
+                if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
+                        line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
+                        line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
+                        line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+
+                    # # Compute the segment length
+                    # length = np.linalg.norm(line[j][0] - line[j][1])
+                    # if length > score_max:
+                    #     tmp = line[j]
+                    #     score_max = score[j]
+
+                    if score[j] > score_max:
+                        tmp = line[j]
+                        score_max = score[j]
+            line_.append(tmp)
+            score_.append(score_max)
+        processed_list = torch.tensor(line_)
+        pred[idx]['line'] = processed_list
+
+        processed_s_list = torch.tensor(score_)
+        pred[idx]['line_score'] = processed_s_list
+    return pred
+
+
+def box_line_optimized(pred):
+    # Create an R-tree index
+    idx = index.Index()
+
+    # Add all line segments to the R-tree
+    lines = pred[-1]['wires']['lines']  # shape [1, 2500, 2, 2]
+    scores = pred[-1]['wires']['score'][0]  # assumed shape [2500]
+
+    # Extract and process all line segments
+    for idx_line in range(lines.shape[1]):  # iterate over the 2500 line segments
+        line_tensor = lines[0, idx_line].cpu().numpy() / 128 * 512  # convert to numpy and rescale to image coordinates
+        x_min = float(min(line_tensor[0][0], line_tensor[1][0]))
+        y_min = float(min(line_tensor[0][1], line_tensor[1][1]))
+        x_max = float(max(line_tensor[0][0], line_tensor[1][0]))
+        y_max = float(max(line_tensor[0][1], line_tensor[1][1]))
+        idx.insert(idx_line, (max(0, x_min - 256), max(0, y_min - 256), min(512, x_max + 256), min(512, y_max + 256)))
+
+    for idx_box, box_ in enumerate(pred[0:-1]):
+        box = box_['boxes'].cpu().numpy()  # make sure the tensor is converted to a numpy array
+        line_ = []
+        score_ = []
+
+        for i in box:
+            score_max = 0.0
+            tmp = [[0.0, 0.0], [0.0, 0.0]]
+
+            # Get all line segments that may intersect the current box
+            possible_matches = list(idx.intersection((i[0], i[1], i[2], i[3])))
+
+            for j in possible_matches:
+                line_j = lines[0, j].cpu().numpy() / 128 * 512
+                if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and  # 注意这里交换了x和y
+                        line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
+                        line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
+                        line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
+
+                    if scores[j] > score_max:
+                        tmp = line_j
+                        score_max = scores[j]
+
+            line_.append(tmp)
+            score_.append(score_max)
+
+        processed_list = torch.tensor(line_)
+        pred[idx_box]['line'] = processed_list
+
+        processed_s_list = torch.tensor(score_)
+        pred[idx_box]['line_score'] = processed_s_list
+
+    return pred
+
+
+def set_thresholds(threshold):
+    if isinstance(threshold, list):
+        if len(threshold) != 2:
+            raise ValueError("Threshold list must contain exactly two elements.")
+        a, b = threshold
+    elif isinstance(threshold, (int, float)):
+        a = b = threshold
+    else:
+        raise TypeError("Threshold must be either a list of two numbers or a single number.")
+
+    return a, b
+
+def color():
+    return  [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+
+def show_all(imgs, pred, threshold, save_path, show):
+    col = color()
+    box_th, line_th = set_thresholds(threshold)
+    im = imgs.permute(1, 2, 0)
+
+    boxes = pred[0]['boxes'].cpu().numpy()
+    box_scores = pred[0]['scores'].cpu().numpy()
+    line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_score = pred[-1]['wires']['score'].cpu().numpy()[0]
+
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    line, line_score = postprocess(line, line_score, diag * 0.01, 0, False)
+
+    fig, axs = plt.subplots(1, 3, figsize=(10, 10))
+
+    axs[0].imshow(np.array(im))
+    for idx, box in enumerate(boxes):
+        if box_scores[idx] < box_th:
+            continue
+        x0, y0, x1, y1 = box
+        axs[0].add_patch(
+            plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+    axs[0].set_title('Boxes')
+
+    axs[1].imshow(np.array(im))
+    for idx, (a, b) in enumerate(line):
+        if line_score[idx] < line_th:
+            continue
+        axs[1].scatter(a[1], a[0], c='#871F78', s=2)
+        axs[1].scatter(b[1], b[0], c='#871F78', s=2)
+        axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+    axs[1].set_title('Lines')
+
+    axs[2].imshow(np.array(im))
+    lines = pred[0]['line'].cpu().numpy()
+    line_scores = pred[0]['line_score'].cpu().numpy()
+    idx = 0
+    tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
+    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
+        x0, y0, x1, y1 = box
+        # Skip boxes that contain no line
+        if np.array_equal(line, tmp):
+            continue
+        a, b = line
+        if box_score >= 0.7 or line_score >= 0.9:
+            axs[2].add_patch(
+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+            axs[2].scatter(a[1], a[0], c='#871F78', s=10)
+            axs[2].scatter(b[1], b[0], c='#871F78', s=10)
+            axs[2].plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
+            idx = idx + 1
+    axs[2].set_title('Boxes and Lines')
+
+    if save_path:
+        save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box_line.png')
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+        plt.savefig(save_path)
+        print(f"Saved result image to {save_path}")
+
+    if show:
+        # Adjust spacing between subplots so titles and labels do not overlap
+        plt.tight_layout()
+        plt.show()
+
+def show_box_or_line(imgs, pred, threshold, save_path = None, show_line=False, show_box=False):
+    col = color()
+    box_th, line_th = set_thresholds(threshold)
+    im = imgs.permute(1, 2, 0)
+
+    boxes = pred[0]['boxes'].cpu().numpy()
+    box_scores = pred[0]['scores'].cpu().numpy()
+    line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
+    line_score = pred[-1]['wires']['score'].cpu().numpy()[0]
+
+    # Visualize the predictions
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(im))
+
+    if show_box:
+        for idx, box in enumerate(boxes):
+            if box_scores[idx] < box_th:
+                continue
+            x0, y0, x1, y1 = box
+            ax.add_patch(
+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+        if save_path:
+            save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box.png')
+            os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+            plt.savefig(save_path)
+            print(f"Saved result image to {save_path}")
+
+
+    if show_line:
+        for idx, (a, b) in enumerate(line):
+            if line_score[idx] < line_th:
+                continue
+            ax.scatter(a[1], a[0], c='#871F78', s=2)
+            ax.scatter(b[1], b[0], c='#871F78', s=2)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
+        if save_path:
+            save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'line.png')
+            os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+            plt.savefig(save_path)
+            print(f"Saved result image to {save_path}")
+
+
+    plt.show()
+
+
+def show_predict(imgs, pred, threshold, t_start):
+    col = color()
+    box_th, line_th = set_thresholds(threshold)
+    im = imgs.permute(1, 2, 0)  # reshaped to [512, 512, 3]
+    boxes = pred[0]['boxes'].cpu().numpy()
+    box_scores = pred[0]['scores'].cpu().numpy()
+    lines = pred[0]['line'].cpu().numpy()
+    line_scores = pred[0]['line_score'].cpu().numpy()
+
+    # Visualize the predictions
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(im))
+    idx = 0
+
+    tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
+
+    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
+        x0, y0, x1, y1 = box
+        # Skip boxes that contain no line
+        if np.array_equal(line, tmp):
+            continue
+        a, b = line
+        if box_score >= box_th or line_score >= line_th:
+            ax.add_patch(
+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+            ax.scatter(a[1], a[0], c='#871F78', s=10)
+            ax.scatter(b[1], b[0], c='#871F78', s=10)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
+            idx = idx + 1
+    t_end = time.time()
+    print(f'predict used:{t_end - t_start}')
+
+    plt.show()
+
+
+def predict(pt_path, model, img, type=0, threshold=0.5, save_path=None, show=False):
+    model = load_best_model(model, pt_path, device)
+
+    model.eval()
+
+    if isinstance(img, str):
+        img = Image.open(img).convert("RGB")
+
+    transform = transforms.ToTensor()
+    img_tensor = transform(img)  # [3, 512, 512]
+
+    # Resize the image to 512x512
+    t_start = time.time()
+    im = img_tensor.permute(1, 2, 0)  # [512, 512, 3]
+    im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3)
+    img_ = torch.tensor(im_resized).permute(2, 0, 1)
+    t_end = time.time()
+    print(f'switch img used:{t_end - t_start}')
+
+    with torch.no_grad():
+        predictions = model([img_.to(device)])
+        # print(predictions)
+
+    t_start = time.time()
+    pred = box_line_(img_, predictions)  # match lines to boxes
+    t_end = time.time()
+    print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds')
+
+    if type == 0:
+        show_all(img_, pred, threshold, save_path=True, show=True)
+    elif type == 1:
+        show_box_or_line(img_, predictions, threshold, save_path=True, show_line=True)  # the flags decide what gets drawn
+    elif type == 2:
+        show_box_or_line(img_, predictions, threshold, save_path=True, show_box=True)  # the flags decide what gets drawn
+    elif type == 3:
+        show_predict(img_, pred, threshold, t_start)
+
+
+if __name__ == '__main__':
+    t_start = time.time()
+    print(f'start to predict:{t_start}')
+    model = linenet_resnet50_fpn().to(device)
+    pt_path = r'D:\python\PycharmProjects\20250214\weight\best.pth'
+    img_path = r'C:\Users\m2337\Desktop\p\20250226142919.png'
+    # predict(pt_path, model, img_path)
+
+    predict(pt_path, model, img_path, type=2, threshold=0.5, save_path=None, show=False)
+
+    t_end = time.time()
+    print(f'predict used:{t_end - t_start}')

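Note: predict_0226.py wraps the whole flow in a single predict() entry point whose type argument selects the output: 0 draws all three panels via show_all, 1 draws lines only, 2 draws boxes only, and 3 draws the matched box/line pairs via show_predict; threshold may be a single number or a [box_threshold, line_threshold] pair. A hedged usage sketch (the weight and image paths are placeholders):

import torch
from models.line_detect.line_net import linenet_resnet50_fpn
from models.line_detect.predict_0226 import predict

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = linenet_resnet50_fpn().to(device)

# Draw boxes, lines, and matched pairs side by side with separate thresholds.
predict(r'path/to/best.pth', model, r'path/to/image.png',
        type=0, threshold=[0.7, 0.9], save_path=None, show=True)
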
+ 47 - 3
models/line_detect/roi_heads.py

@@ -1,3 +1,4 @@
+import time
 from typing import Dict, List, Optional, Tuple
 
 import torch
@@ -10,6 +11,8 @@ import libs.vision_libs.models.detection._utils as det_utils
 
 from collections import OrderedDict
 
+from models.wirenet.postprocess import postprocess
+
 
 def l2loss(input, target):
     return ((target - input) ** 2).mean(2).mean(1)
@@ -146,7 +149,6 @@ def line_vectorizer_loss(result, x, ys, idx, jcs, n_batch, ps, n_out_line, n_out
     return result
 
 
-
 def wirepoint_head_line_loss(targets, output, x, y, idx, loss_weight):
     # output, feature: results returned by the head
     # x, y, idx: intermediate results from the line branch
@@ -257,6 +259,48 @@ def wirepoint_inference(input, idx, jcs, n_batch, ps, n_out_line, n_out_junc):
     return result
 
 
+def wirepoint_line_target_loss(input, idx, n_batch, ps, n_out_line, targets):
+    # result = {}
+    # result["wires"] = {}
+    p = torch.cat(ps)
+    s = torch.sigmoid(input)
+    b = s > 0.5
+    lines = []
+    score = []
+    # print(f"n_batch:{n_batch}")
+    for i in range(n_batch):
+        # print(f"idx:{idx}")
+        p0 = p[idx[i]: idx[i + 1]]
+        s0 = s[idx[i]: idx[i + 1]]
+        mask = b[idx[i]: idx[i + 1]]
+        p0 = p0[mask]
+        s0 = s0[mask]
+        if len(p0) == 0:
+            lines.append(torch.zeros([1, n_out_line, 2, 2], device=p.device))
+            score.append(torch.zeros([1, n_out_line], device=p.device))
+        else:
+            arg = torch.argsort(s0, descending=True)
+            p0, s0 = p0[arg], s0[arg]
+            lines.append(p0[None, torch.arange(n_out_line) % len(p0)])
+            score.append(s0[None, torch.arange(n_out_line) % len(s0)])
+
+    line_num_loss = []
+    line_maps = [sum(t["wires"]["lpre_label"]) for t in targets]
+    start = time.time()
+    diag = (512 ** 2 + 512 ** 2) ** 0.5
+    line1 = torch.cat(lines).cpu().numpy()
+    score1 = torch.cat(score).cpu().detach().numpy()
+    for i in range(line1.shape[0]):
+        line_tmp = line1[i] / 128 * 512
+        score_tmp = score1[i]
+        lines2, _ = postprocess(line_tmp, score_tmp, diag * 0.01, 0, False)
+        line_num_loss.append(len(lines2)-line_maps[i])
+    end = time.time()
+    print(end - start)
+
+    return sum(line_num_loss) / len(line_num_loss)
+
+
 def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
     # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
     """
@@ -1071,7 +1115,9 @@ class RoIHeads(nn.Module):
 
             if self.training:
                 rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
+                # line_target_loss = wirepoint_line_target_loss(x, idx, n_batch, ps, n_out_line, targets)  # difference in line counts
                 loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
+                # loss_wirepoint["loss_wirepoint"]["line_target_loss"] = line_target_loss   # line-count difference loss
             else:
 
                 pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
@@ -1082,8 +1128,6 @@ class RoIHeads(nn.Module):
             pass
             # print('has not line_head')
 
-
-
         if self.has_mask():
             mask_proposals = [p["boxes"] for p in result]
             if self.training:

+ 166 - 0
models/line_detect/static/index.html

@@ -0,0 +1,166 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>LineDetect Alpha_v1.02</title>
+    <style>
+        body {
+            font-family: Arial, sans-serif;
+            margin: 0;
+            padding: 0;
+            display: flex;
+            justify-content: center;
+            align-items: center;
+            height: 100vh;
+            background-color: #f0f4f8;
+        }
+        .container {
+            text-align: center;
+            width: 80%;
+            max-width: 2000px;
+            background-color: #ffffff;
+            border-radius: 8px;
+            padding: 40px;
+            box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
+            overflow: auto; /* allow scrolling inside the card */
+        }
+        h1 {
+            color: #333;
+            font-size: 24px;
+            margin-bottom: 20px;
+        }
+        h2.subtitle {
+            color: #666;
+            font-size: 18px;
+            margin-bottom: 20px;
+        }
+        input[type="file"] {
+            margin-bottom: 20px;
+            padding: 10px;
+            font-size: 16px;
+            border: 1px solid #ccc;
+            border-radius: 4px;
+        }
+        button {
+            padding: 12px 20px;
+            background-color: #4CAF50;
+            color: white;
+            border: none;
+            border-radius: 4px;
+            font-size: 16px;
+            cursor: pointer;
+            transition: background-color 0.3s;
+        }
+        button:hover {
+            background-color: #45a049;
+        }
+        .result {
+            visibility: hidden; /* hidden until a result is available */
+            display: flex;
+            flex-direction: column;
+            align-items: center;
+            margin-top: 40px;
+            max-height: 60vh; /* cap the result area height */
+            overflow-y: auto; /* scroll vertically when content overflows */
+        }
+        .result img {
+            max-width: 100%; /* never exceed the container width */
+            height: auto; /* keep the aspect ratio */
+            object-fit: cover;
+            border-radius: 8px;
+            box-shadow: 0 4px 10px rgba(0, 0, 0, 0.1);
+            margin-bottom: 20px;
+        }
+        .uploaded-image {
+            width: 512px; /* fixed width */
+            height: 512px; /* fixed height */
+            object-fit: cover; /* crop to fill while keeping the aspect ratio */
+            border-radius: 8px;
+            box-shadow: 0 4px 10px rgba(0, 0, 0, 0.1);
+        }
+    </style>
+</head>
+<body>
+    <div class="container">
+        <h1>LineDetect Alpha_v1.02</h1>
+        <h2 class="subtitle">Please upload an image</h2>
+        <input type="file" id="fileInput" accept="image/*" />
+        <button onclick="uploadImage()">Upload</button>
+        <div class="result">
+            <div>
+                <h2>Original image</h2>
+                <img id="uploadedImage" class="uploaded-image" src="" alt="Uploaded Image" style="display: none;" />
+            </div>
+            <div>
+                <h2>Result</h2>
+                <img id="resultImage" src="" alt="Prediction Result" style="display: none;" />
+            </div>
+        </div>
+    </div>
+    <script>
+        async function uploadImage() {
+            const fileInput = document.getElementById('fileInput');
+            const file = fileInput.files[0];
+            if (!file) {
+                alert("Please select an image file first.");
+                return;
+            }
+
+            // resize the image to 512x512 with a canvas
+            const resizedImageBlob = await resizeImage(file, 512, 512);
+
+            // show the resized uploaded image
+            const uploadedImage = document.getElementById('uploadedImage');
+            uploadedImage.src = URL.createObjectURL(resizedImageBlob);
+            uploadedImage.style.display = 'block';
+
+            // send the resized image to the backend
+            const formData = new FormData();
+            formData.append("file", resizedImageBlob, file.name); // keep the original file name
+            const response = await fetch('http://192.168.50.100:808/predict', {
+                method: 'POST',
+                body: formData,
+            });
+            if (response.ok) {
+                const resultBlob = await response.blob();
+                const resultImage = document.getElementById('resultImage');
+                resultImage.src = URL.createObjectURL(resultBlob);
+                resultImage.style.display = 'block';
+                document.querySelector('.result').style.visibility = 'visible'; // show the result panel
+            } else {
+                alert("Prediction failed.");
+            }
+        }
+
+        /**
+         * Resize an image with a canvas.
+         * @param {File} file - the image file to resize
+         * @param {number} width - target width
+         * @param {number} height - target height
+         * @returns {Promise<Blob>} - the resized image as a Blob
+         */
+        async function resizeImage(file, width, height) {
+            return new Promise((resolve, reject) => {
+                const img = new Image();
+                img.onload = () => {
+                    const canvas = document.createElement('canvas');
+                    canvas.width = width;
+                    canvas.height = height;
+                    const ctx = canvas.getContext('2d');
+
+                    // draw the image scaled to the target size
+                    ctx.drawImage(img, 0, 0, width, height);
+
+                    // convert the canvas to a Blob
+                    canvas.toBlob(blob => {
+                        resolve(blob);
+                    }, file.type || 'image/jpeg');
+                };
+                img.onerror = reject;
+                img.src = URL.createObjectURL(file);
+            });
+        }
+    </script>
+</body>
+</html>

BIN
models/line_detect/temp_result_box.png


BIN
models/line_detect/temp_result_line.png


BIN
models/line_detect/uploaded_images/020.png


BIN
models/line_detect/uploaded_images/20.jpg


BIN
models/line_detect/uploaded_images/21.jpg


BIN
models/line_detect/uploaded_images/49.jpg


BIN
models/line_detect/uploaded_images/9.jpg


+ 166 - 0
static/index.html

@@ -0,0 +1,166 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>LineDetect Alpha_v1.02</title>
+    <style>
+        body {
+            font-family: Arial, sans-serif;
+            margin: 0;
+            padding: 0;
+            display: flex;
+            justify-content: center;
+            align-items: center;
+            height: 100vh;
+            background-color: #f0f4f8;
+        }
+        .container {
+            text-align: center;
+            width: 80%;
+            max-width: 2000px;
+            background-color: #ffffff;
+            border-radius: 8px;
+            padding: 40px;
+            box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
+            overflow: auto; /* allow scrolling inside the card */
+        }
+        h1 {
+            color: #333;
+            font-size: 24px;
+            margin-bottom: 20px;
+        }
+        h2.subtitle {
+            color: #666;
+            font-size: 18px;
+            margin-bottom: 20px;
+        }
+        input[type="file"] {
+            margin-bottom: 20px;
+            padding: 10px;
+            font-size: 16px;
+            border: 1px solid #ccc;
+            border-radius: 4px;
+        }
+        button {
+            padding: 12px 20px;
+            background-color: #4CAF50;
+            color: white;
+            border: none;
+            border-radius: 4px;
+            font-size: 16px;
+            cursor: pointer;
+            transition: background-color 0.3s;
+        }
+        button:hover {
+            background-color: #45a049;
+        }
+        .result {
+            visibility: hidden; /* hidden until a result is available */
+            display: flex;
+            flex-direction: column;
+            align-items: center;
+            margin-top: 40px;
+            max-height: 60vh; /* cap the result area height */
+            overflow-y: auto; /* scroll vertically when content overflows */
+        }
+        .result img {
+            max-width: 100%; /* never exceed the container width */
+            height: auto; /* keep the aspect ratio */
+            object-fit: cover;
+            border-radius: 8px;
+            box-shadow: 0 4px 10px rgba(0, 0, 0, 0.1);
+            margin-bottom: 20px;
+        }
+        .uploaded-image {
+            width: 512px; /* fixed width */
+            height: 512px; /* fixed height */
+            object-fit: cover; /* crop to fill while keeping the aspect ratio */
+            border-radius: 8px;
+            box-shadow: 0 4px 10px rgba(0, 0, 0, 0.1);
+        }
+    </style>
+</head>
+<body>
+    <div class="container">
+        <h1>LineDetect Alpha_v1.02</h1>
+        <h2 class="subtitle">Please upload an image</h2>
+        <input type="file" id="fileInput" accept="image/*" />
+        <button onclick="uploadImage()">Upload</button>
+        <div class="result">
+            <div>
+                <h2>Original image</h2>
+                <img id="uploadedImage" class="uploaded-image" src="" alt="Uploaded Image" style="display: none;" />
+            </div>
+            <div>
+                <h2>Result</h2>
+                <img id="resultImage" src="" alt="Prediction Result" style="display: none;" />
+            </div>
+        </div>
+    </div>
+    <script>
+        async function uploadImage() {
+            const fileInput = document.getElementById('fileInput');
+            const file = fileInput.files[0];
+            if (!file) {
+                alert("Please select an image file first.");
+                return;
+            }
+
+            // resize the image to 512x512 with a canvas
+            const resizedImageBlob = await resizeImage(file, 512, 512);
+
+            // show the resized uploaded image
+            const uploadedImage = document.getElementById('uploadedImage');
+            uploadedImage.src = URL.createObjectURL(resizedImageBlob);
+            uploadedImage.style.display = 'block';
+
+            // send the resized image to the backend
+            const formData = new FormData();
+            formData.append("file", resizedImageBlob, file.name); // keep the original file name
+            const response = await fetch('http://0.0.0.0:808/predict', {
+                method: 'POST',
+                body: formData,
+            });
+            if (response.ok) {
+                const resultBlob = await response.blob();
+                const resultImage = document.getElementById('resultImage');
+                resultImage.src = URL.createObjectURL(resultBlob);
+                resultImage.style.display = 'block';
+                document.querySelector('.result').style.visibility = 'visible'; // show the result panel
+            } else {
+                alert("Prediction failed.");
+            }
+        }
+
+        /**
+         * Resize an image with a canvas.
+         * @param {File} file - the image file to resize
+         * @param {number} width - target width
+         * @param {number} height - target height
+         * @returns {Promise<Blob>} - the resized image as a Blob
+         */
+        async function resizeImage(file, width, height) {
+            return new Promise((resolve, reject) => {
+                const img = new Image();
+                img.onload = () => {
+                    const canvas = document.createElement('canvas');
+                    canvas.width = width;
+                    canvas.height = height;
+                    const ctx = canvas.getContext('2d');
+
+                    // draw the image scaled to the target size
+                    ctx.drawImage(img, 0, 0, width, height);
+
+                    // convert the canvas to a Blob
+                    canvas.toBlob(blob => {
+                        resolve(blob);
+                    }, file.type || 'image/jpeg');
+                };
+                img.onerror = reject;
+                img.src = URL.createObjectURL(file);
+            });
+        }
+    </script>
+</body>
+</html>