Browse Source

第3版 box与line对应绘图 没有进行并行运算

xue50 3 months ago
parent
commit
1098a532f7
5 changed files with 312 additions and 11 deletions
  1. 3 1
      .gitignore
  2. 70 10
      models/line_detect/postprocess.py
  3. 130 0
      models/line_detect/predict.py
  4. 93 0
      models/line_detect/predict2.py
  5. 16 0
      readme.md

+ 3 - 1
.gitignore

@@ -28,4 +28,6 @@ checkpoint
 /.ipynb_checkpoints/
 
 __pycache__
-train_results
+train_results
+
+models/line_detect/linenet_wts

+ 70 - 10
models/line_detect/postprocess.py

@@ -1,8 +1,12 @@
+import time
+
 import torch
 import matplotlib.pyplot as plt
 import numpy as np
 from torchvision import transforms
 
+from models.wirenet.postprocess import postprocess
+
 
 def box_line(pred):
     '''
@@ -34,27 +38,33 @@ def box_line(pred):
 
 
 def box_line_(pred):
-    '''
-    形式同pred
-    '''
     for idx, box_ in enumerate(pred[0:-1]):
         box = box_['boxes']  # 是一个tensor
         line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
         score = pred[-1]['wires']['score'][idx]
         line_ = []
+        score_ = []
+
         for i in box:
             score_max = 0.0
             tmp = [[0.0, 0.0], [0.0, 0.0]]
+
             for j in range(len(line)):
-                if (line[j][0][0] >= i[0] and line[j][1][0] >= i[0] and line[j][0][0] <= i[2] and
-                        line[j][1][0] <= i[2] and line[j][0][1] >= i[1] and line[j][1][1] >= i[1] and
-                        line[j][0][1] <= i[3] and line[j][1][1] <= i[3]):
+                if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
+                        line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
+                        line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
+                        line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+
                     if score[j] > score_max:
                         tmp = line[j]
                         score_max = score[j]
             line_.append(tmp)
+            score_.append(score_max)
         processed_list = torch.tensor(line_)
         pred[idx]['line'] = processed_list
+
+        processed_s_list = torch.tensor(score_)
+        pred[idx]['line_score'] = processed_s_list
     return pred
 
 
@@ -81,7 +91,7 @@ def show_(imgs, pred, epoch, writer):
         '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
         '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
     ]
-    print(len(col))
+    # print(len(col))
     im = imgs[0].permute(1, 2, 0)
     boxes = pred[0]['boxes'].cpu().numpy()
     line = pred[0]['line'].cpu().numpy()
@@ -96,9 +106,9 @@ def show_(imgs, pred, epoch, writer):
             plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
 
     for idx, (a, b) in enumerate(line):
-        ax.scatter(a[0], a[1], c=col[99 - idx], s=2)
-        ax.scatter(b[0], b[1], c=col[99 - idx], s=2)
-        ax.plot([a[0], b[0]], [a[1], b[1]], c=col[idx], linewidth=1)
+        ax.scatter(a[1], a[0], c='#871F78', s=2)
+        ax.scatter(b[1], b[0], c='#871F78', s=2)
+        ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
 
     # 将Matplotlib图像转换为Tensor
     fig.canvas.draw()
@@ -109,3 +119,53 @@ def show_(imgs, pred, epoch, writer):
 
     writer.add_image("all", img2, epoch)
 
+
+def show_predict(imgs, pred, t_start):
+    col = [
+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
+    ]
+    print(len(col))
+    im = imgs.permute(1, 2, 0)
+    boxes = pred[0]['boxes'].cpu().numpy()
+    box_scores = pred[0]['scores'].cpu().numpy()
+    lines = pred[0]['line'].cpu().numpy()
+    line_scores = pred[0]['line_score'].cpu().numpy()
+
+    # 可视化预测结
+    fig, ax = plt.subplots(figsize=(10, 10))
+    ax.imshow(np.array(im))
+    idx = 0
+
+    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
+        x0, y0, x1, y1 = box
+        a, b = line
+        if box_score > 0.7 and line_score > 0.9:
+            ax.add_patch(
+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
+            ax.scatter(a[1], a[0], c='#871F78', s=2)
+            ax.scatter(b[1], b[0], c='#871F78', s=2)
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
+            idx = idx + 1
+    t_end = time.time()
+    print(f'predict used:{t_end - t_start}')
+
+    plt.show()

+ 130 - 0
models/line_detect/predict.py

@@ -0,0 +1,130 @@
+import os
+
+import torch
+from PIL import Image
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+import numpy as np
+from models.line_detect.line_net import linenet_resnet50_fpn
+from torchvision import transforms
+
+from models.wirenet.postprocess import postprocess
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def load_best_model(model, save_path, device):
+    if os.path.exists(save_path):
+        checkpoint = torch.load(save_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        # if optimizer is not None:
+        #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        epoch = checkpoint['epoch']
+        loss = checkpoint['loss']
+        print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+    else:
+        print(f"No saved model found at {save_path}")
+    return model
+
+
+cmap = plt.get_cmap("jet")
+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
+sm.set_array([])
+
+
+def c(x):
+    return sm.to_rgba(x)
+
+
+def imshow(im):
+    plt.close()
+    plt.tight_layout()
+    plt.imshow(im)
+    plt.colorbar(sm, fraction=0.046)
+    plt.xlim([0, im.shape[0]])
+    plt.ylim([im.shape[0], 0])
+
+
+def show_line(img, pred):
+    im = img.permute(1, 2, 0)
+
+    # 创建图形和坐标轴
+    fig, ax = plt.subplots(figsize=(10, 10))
+    # 绘制原始图像
+    ax.imshow(np.array(im))
+    # 绘制边界框
+    boxes = pred[0]['boxes'].cpu().numpy()
+    boxes_scores = pred[0]['scores'].cpu().numpy()
+
+    # for box in boxes:
+    #     x0, y0, x1, y1 = box
+    #     rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1)
+    #     ax.add_patch(rect)  # 将矩形添加到 Axes 对象上
+
+    for b, s in zip(boxes, boxes_scores):
+        # print(f'box:{b}, box_score:{s}')
+        if s < 0.7:
+            continue
+        x0, y0, x1, y1 = b
+        rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1)
+        ax.add_patch(rect)  # 将矩形添加到 Axes 对象上
+
+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+    H = pred[-1]['wires']
+    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+    scores = H["score"][0].cpu().numpy()
+    for i in range(1, len(lines)):
+        if (lines[i] == lines[0]).all():
+            lines = lines[:i]
+            scores = scores[:i]
+            break
+
+    # 后处理线条以去除重叠的线条
+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+    # 根据分数绘制线条
+    for i, t in enumerate([0.9]):
+        for (a, b), s in zip(nlines, nscores):
+            if s < t:
+                continue
+            ax.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)  # 在 Axes 上绘制线条
+            ax.scatter(a[1], a[0], **PLTOPTS)  # 在 Axes 上绘制散点
+            ax.scatter(b[1], b[0], **PLTOPTS)  # 在 Axes 上绘制散点
+
+    # 隐藏坐标轴
+    ax.set_axis_off()
+    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+    plt.margins(0, 0)
+    ax.xaxis.set_major_locator(plt.NullLocator())
+    ax.yaxis.set_major_locator(plt.NullLocator())
+
+    # 显示图像
+    plt.show()
+
+
+def predict(pt_path, model, img):
+    model = load_best_model(model, pt_path, device)
+
+    model.eval()
+
+    if isinstance(img, str):
+        img = Image.open(img).convert("RGB")
+
+    transform = transforms.ToTensor()
+    img_tensor = transform(img)
+
+    with torch.no_grad():
+        predictions = model([img_tensor])
+        print(predictions[0])
+
+    show_line(img_tensor, predictions)
+
+
+if __name__ == '__main__':
+    model = linenet_resnet50_fpn()
+    pt_path = r'D:\python\PycharmProjects\lcnn-master\lcnn_\20250212\MultiVisionModels\models\line_detect\linenet_wts\resnet50_best_e8.pth'
+    # img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png'  # 工件图
+    img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe图
+    predict(pt_path, model, img_path)

+ 93 - 0
models/line_detect/predict2.py

@@ -0,0 +1,93 @@
+import time
+
+from models.line_detect.postprocess import show_predict
+import os
+
+import torch
+from PIL import Image
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+import numpy as np
+from models.line_detect.line_net import linenet_resnet50_fpn
+from torchvision import transforms
+
+# from models.wirenet.postprocess import postprocess
+from models.wirenet.postprocess import postprocess
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def load_best_model(model, save_path, device):
+    if os.path.exists(save_path):
+        checkpoint = torch.load(save_path, map_location=device)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        # if optimizer is not None:
+        #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        epoch = checkpoint['epoch']
+        loss = checkpoint['loss']
+        print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+    else:
+        print(f"No saved model found at {save_path}")
+    return model
+
+def box_line_(pred):
+    for idx, box_ in enumerate(pred[0:-1]):
+        box = box_['boxes']  # 是一个tensor
+        line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
+        score = pred[-1]['wires']['score'][idx]
+        line_ = []
+        score_ = []
+
+        for i in box:
+            score_max = 0.0
+            tmp = [[0.0, 0.0], [0.0, 0.0]]
+
+            for j in range(len(line)):
+                if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
+                        line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
+                        line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
+                        line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
+
+                    if score[j] > score_max:
+                        tmp = line[j]
+                        score_max = score[j]
+            line_.append(tmp)
+            score_.append(score_max)
+        processed_list = torch.tensor(line_)
+        pred[idx]['line'] = processed_list
+
+        processed_s_list = torch.tensor(score_)
+        pred[idx]['line_score'] = processed_s_list
+    return pred
+
+
+def predict(pt_path, model, img):
+    model = load_best_model(model, pt_path, device)
+
+    model.eval()
+
+    if isinstance(img, str):
+        img = Image.open(img).convert("RGB")
+
+    transform = transforms.ToTensor()
+    img_tensor = transform(img)
+
+    with torch.no_grad():
+        predictions = model([img_tensor.to(device)])
+        # print(predictions)
+
+    pred = box_line_(predictions)
+    # print(f'pred:{pred[0]}')
+    show_predict(img_tensor, pred, t_start)
+
+
+if __name__ == '__main__':
+    t_start = time.time()
+    print(f'start to predict:{t_start}')
+    model = linenet_resnet50_fpn().to(device)
+    pt_path = r'D:\python\PycharmProjects\lcnn-master\lcnn_\20250212\MultiVisionModels\models\line_detect\linenet_wts\resnet50_best_e8.pth'
+    img_path = f'D:\python\PycharmProjects\data2\images/train/2024-11-27-15-41-38_SaveImage.png'  # 工件图
+    # img_path = f'D:\python\PycharmProjects\data\images/train/00558656_3.png'  # wireframe图
+    predict(pt_path, model, img_path)
+    t_end = time.time()
+    print(f'predict used:{t_end - t_start}')

+ 16 - 0
readme.md

@@ -18,3 +18,19 @@ A100(40G) train edition
 Include objection dectection ,keypoint detection,instance segment detection and line dectection.
 
 
+
+
+
+第3版的修改
+
+
+
+D:\python\PycharmProjects\lcnn-master\lcnn_\20250212\MultiVisionModels\models\line_detect\predict.py
+
+加载权重进行训练。设置box和line的阈值,将box和line画在原图上
+
+D:\python\PycharmProjects\lcnn-master\lcnn_\20250212\MultiVisionModels\models\line_detect\predict2.py
+
+加载权重进行训练。设置box和line的阈值,将box和line画在原图上,且一一对应。但没有并行训练
+
+此处将所有一一对应的box与line都画出来,且进行阈值限制,box为0.7,line为0.9