|
@@ -40,6 +40,42 @@ def c(x):
|
|
|
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
|
|
|
|
|
+import matplotlib.pyplot as plt
|
|
|
|
|
+from PIL import ImageDraw
|
|
|
|
|
+from torchvision.transforms import functional as F
|
|
|
|
|
+import torch
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# 由低到高蓝黄红
|
|
|
|
|
+def draw_lines_with_scores(tensor_image, lines, scores, width=3, cmap='viridis'):
|
|
|
|
|
+ """
|
|
|
|
|
+ 根据得分对线段着色并绘制
|
|
|
|
|
+ :param tensor_image: (3, H, W) uint8 图像
|
|
|
|
|
+ :param lines: (N, 2, 2) 每条线 [ [x1,y1], [x2,y2] ]
|
|
|
|
|
+ :param scores: (N,) 每条线的得分,范围 [0, 1]
|
|
|
|
|
+ :param width: 线宽
|
|
|
|
|
+ :param cmap: matplotlib colormap 名称,例如 'viridis', 'jet', 'coolwarm'
|
|
|
|
|
+ :return: (3, H, W) uint8 画好线的图像
|
|
|
|
|
+ """
|
|
|
|
|
+ assert tensor_image.dtype == torch.uint8
|
|
|
|
|
+ assert tensor_image.shape[0] == 3
|
|
|
|
|
+ assert lines.shape[0] == scores.shape[0]
|
|
|
|
|
+
|
|
|
|
|
+ # 准备色图
|
|
|
|
|
+ colormap = plt.get_cmap(cmap)
|
|
|
|
|
+ colors = (colormap(scores.cpu().numpy())[:, :3] * 255).astype('uint8') # 去掉 alpha 通道
|
|
|
|
|
+
|
|
|
|
|
+ # 转为 PIL 画图
|
|
|
|
|
+ image_pil = F.to_pil_image(tensor_image)
|
|
|
|
|
+ draw = ImageDraw.Draw(image_pil)
|
|
|
|
|
+
|
|
|
|
|
+ for line, color in zip(lines, colors):
|
|
|
|
|
+ start = tuple(map(float, line[0][:2].tolist()))
|
|
|
|
|
+ end = tuple(map(float, line[1][:2].tolist()))
|
|
|
|
|
+ draw.line([start, end], fill=tuple(color), width=width)
|
|
|
|
|
+
|
|
|
|
|
+ return (F.to_tensor(image_pil) * 255).to(torch.uint8)
|
|
|
|
|
+
|
|
|
|
|
|
|
|
class Trainer(BaseTrainer):
|
|
class Trainer(BaseTrainer):
|
|
|
def __init__(self, model=None, **kwargs):
|
|
def __init__(self, model=None, **kwargs):
|
|
@@ -147,6 +183,10 @@ class Trainer(BaseTrainer):
|
|
|
print(f"No saved model found at {save_path}")
|
|
print(f"No saved model found at {save_path}")
|
|
|
return model, optimizer
|
|
return model, optimizer
|
|
|
|
|
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def writer_predict_result(self, img, result, epoch):
|
|
def writer_predict_result(self, img, result, epoch):
|
|
|
img = img.cpu().detach()
|
|
img = img.cpu().detach()
|
|
|
im = img.permute(1, 2, 0) # [512, 512, 3]
|
|
im = img.permute(1, 2, 0) # [512, 512, 3]
|
|
@@ -162,6 +202,19 @@ class Trainer(BaseTrainer):
|
|
|
keypoint_img = draw_keypoints(boxed_image, result['lines'], colors='red', width=3)
|
|
keypoint_img = draw_keypoints(boxed_image, result['lines'], colors='red', width=3)
|
|
|
|
|
|
|
|
self.writer.add_image("z-output", keypoint_img, epoch)
|
|
self.writer.add_image("z-output", keypoint_img, epoch)
|
|
|
|
|
+ print("lines shape:", result['lines'].shape)
|
|
|
|
|
+
|
|
|
|
|
+ # 用自己写的函数画线段
|
|
|
|
|
+ # line_image = draw_lines(boxed_image, result['lines'], color='red', width=3)
|
|
|
|
|
+ print(f"shape of linescore:{result['liness_scores'].shape}")
|
|
|
|
|
+ scores = result['liness_scores'].mean(dim=1) # shape: [31]
|
|
|
|
|
+
|
|
|
|
|
+ line_image = draw_lines_with_scores((img * 255).to(torch.uint8), result['lines'],scores, width=3, cmap='jet')
|
|
|
|
|
+
|
|
|
|
|
+ self.writer.add_image("z-output_line", line_image.permute(1, 2, 0), epoch, dataformats="HWC")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
|
|
|
def writer_loss(self, losses, epoch, phase='train'):
|
|
def writer_loss(self, losses, epoch, phase='train'):
|
|
|
try:
|
|
try:
|