Rookie 5 mesiacov pred
rodič
commit
70ba0a709c

+ 36 - 0
models/line_detect/show_score.py

@@ -0,0 +1,36 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+
+def save_color_gradient(cmap_name='jet', width=512, height=50, filename=None):
+    """
+    保存颜色从 0 到 1 渐变的图像
+    :param cmap_name: 色图名,如 'jet', 'viridis'
+    :param width: 渐变图宽度(像素)
+    :param height: 渐变图高度(像素)
+    :param filename: 保存的文件名(可选)
+    """
+    cmap = plt.get_cmap(cmap_name)
+    gradient = np.linspace(0, 1, width).reshape(1, -1)
+    gradient = np.repeat(gradient, height, axis=0)
+
+    plt.figure(figsize=(8, 1.5))
+    plt.imshow(gradient, aspect='auto', cmap=cmap)
+    plt.title(f"Color Gradient (colormap = '{cmap_name}')")
+    plt.xlabel("Score: 0 → 1")
+    plt.yticks([])
+    plt.xticks([0, width//4, width//2, 3*width//4, width], [0.0, 0.25, 0.5, 0.75, 1.0])
+    plt.tight_layout()
+
+    # 默认文件名
+    if filename is None:
+        filename = f"color_gradient_{cmap_name}.png"
+
+    # 保存图像
+    plt.savefig(filename, dpi=150)
+    plt.close()
+
+    print(f"已保存: {os.path.abspath(filename)}")
+
+# 调用:保存 jet 渐变图
+save_color_gradient('jet')

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: /data/share/rlq/datasets/250612
+  datadir: G:\python_ws_g\data\250612
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
 

+ 2 - 2
models/line_detect/train_demo.py

@@ -10,12 +10,12 @@ if __name__ == '__main__':
 
     # model = LineNet('line_net.yaml')
     # model=linenet_resnet50_fpn()
-    # model = linenet_resnet18_fpn()
+    model = linedetect_resnet18_fpn()
     # model=get_line_net_convnext_fpn(num_classes=2).to(device)
     # model=linenet_newresnet50fpn()
     # model = lineDetect_resnet18_fpn()
 
     # model=linedetect_resnet18_fpn()
-    model=linedetect_newresnet18fpn()
+    # model=linedetect_newresnet18fpn()
 
     model.start_train(cfg='train.yaml')

+ 53 - 0
models/line_detect/trainer.py

@@ -40,6 +40,42 @@ def c(x):
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
+import matplotlib.pyplot as plt
+from PIL import ImageDraw
+from torchvision.transforms import functional as F
+import torch
+
+
+# 由低到高蓝黄红
+def draw_lines_with_scores(tensor_image, lines, scores, width=3, cmap='viridis'):
+    """
+    根据得分对线段着色并绘制
+    :param tensor_image: (3, H, W) uint8 图像
+    :param lines: (N, 2, 2) 每条线 [ [x1,y1], [x2,y2] ]
+    :param scores: (N,) 每条线的得分,范围 [0, 1]
+    :param width: 线宽
+    :param cmap: matplotlib colormap 名称,例如 'viridis', 'jet', 'coolwarm'
+    :return: (3, H, W) uint8 画好线的图像
+    """
+    assert tensor_image.dtype == torch.uint8
+    assert tensor_image.shape[0] == 3
+    assert lines.shape[0] == scores.shape[0]
+
+    # 准备色图
+    colormap = plt.get_cmap(cmap)
+    colors = (colormap(scores.cpu().numpy())[:, :3] * 255).astype('uint8')  # 去掉 alpha 通道
+
+    # 转为 PIL 画图
+    image_pil = F.to_pil_image(tensor_image)
+    draw = ImageDraw.Draw(image_pil)
+
+    for line, color in zip(lines, colors):
+        start = tuple(map(float, line[0][:2].tolist()))
+        end = tuple(map(float, line[1][:2].tolist()))
+        draw.line([start, end], fill=tuple(color), width=width)
+
+    return (F.to_tensor(image_pil) * 255).to(torch.uint8)
+
 
 class Trainer(BaseTrainer):
     def __init__(self, model=None, **kwargs):
@@ -147,6 +183,10 @@ class Trainer(BaseTrainer):
             print(f"No saved model found at {save_path}")
         return model, optimizer
 
+
+
+
+
     def writer_predict_result(self, img, result, epoch):
         img = img.cpu().detach()
         im = img.permute(1, 2, 0)  # [512, 512, 3]
@@ -162,6 +202,19 @@ class Trainer(BaseTrainer):
         keypoint_img = draw_keypoints(boxed_image, result['lines'], colors='red', width=3)
 
         self.writer.add_image("z-output", keypoint_img, epoch)
+        print("lines shape:", result['lines'].shape)
+
+        # 用自己写的函数画线段
+        # line_image = draw_lines(boxed_image, result['lines'], color='red', width=3)
+        print(f"shape of linescore:{result['liness_scores'].shape}")
+        scores = result['liness_scores'].mean(dim=1)  # shape: [31]
+
+        line_image = draw_lines_with_scores((img * 255).to(torch.uint8),  result['lines'],scores, width=3, cmap='jet')
+
+        self.writer.add_image("z-output_line", line_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+
+
 
     def writer_loss(self, losses, epoch, phase='train'):
         try: