
Debug the newly modified line_predict and its loss

RenLiqiang 5 months ago
Parent
Commit
35d39c6ce5
3 files changed, 178 insertions and 17 deletions
  1. models/line_detect/line_detect.py  +6 -4
  2. models/line_detect/roi_heads.py  +170 -11
  3. models/line_detect/train.yaml  +2 -2

+ 6 - 4
models/line_detect/line_detect.py

@@ -173,7 +173,7 @@ class LineDetect(BaseDetectionNet):
 
         if line_predictor is None:
             keypoint_dim_reduced = 512  # == keypoint_layers[-1]
-            line_predictor = LinePredictor(keypoint_dim_reduced, num_keypoints)
+            line_predictor = LinePredictor(keypoint_dim_reduced)
 
 
         self.roi_heads.line_roi_pool = line_roi_pool
@@ -305,13 +305,13 @@ class LineHeads(nn.Sequential):
 
 
 class LinePredictor(nn.Module):
-    def __init__(self, in_channels, num_keypoints):
+    def __init__(self, in_channels, out_channels=1 ):
         super().__init__()
         input_features = in_channels
         deconv_kernel = 4
         self.kps_score_lowres = nn.ConvTranspose2d(
             input_features,
-            num_keypoints,
+            out_channels,
             deconv_kernel,
             stride=2,
             padding=deconv_kernel // 2 - 1,
@@ -319,10 +319,12 @@ class LinePredictor(nn.Module):
         nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
         nn.init.constant_(self.kps_score_lowres.bias, 0)
         self.up_scale = 2
-        self.out_channels = num_keypoints
+        self.out_channels = out_channels
 
     def forward(self, x):
+        print(f'before kps_score_lowres x:{x.shape}')
         x = self.kps_score_lowres(x)
+        print(f'kps_score_lowres x:{x.shape}')
         return torch.nn.functional.interpolate(
             x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False, recompute_scale_factor=False
         )
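With LinePredictor now fixed to a single output channel, each ROI produces one line heatmap instead of K keypoint channels. A minimal shape check of the deconv-plus-upsample path, assuming the default 512-dim features from the line head and a hypothetical 14x14 ROI pooling resolution (neither value is stated in this diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in mirroring the updated LinePredictor: a transposed conv that doubles the
# spatial size and reduces to a single channel, followed by a further 2x bilinear upsample.
deconv = nn.ConvTranspose2d(512, 1, kernel_size=4, stride=2, padding=1)

features = torch.randn(8, 512, 14, 14)   # hypothetical line_head output for 8 proposals
low_res = deconv(features)               # (8, 1, 28, 28)
logits = F.interpolate(low_res, scale_factor=2.0, mode="bilinear", align_corners=False)
print(logits.shape)                      # torch.Size([8, 1, 56, 56])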

+ 170 - 11
models/line_detect/roi_heads.py

@@ -1,5 +1,6 @@
 from typing import Dict, List, Optional, Tuple
 
+import matplotlib.pyplot as plt
 import torch
 import torch.nn.functional as F
 import torchvision
@@ -128,6 +129,138 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
     return mask_loss
 
 
+def line_points_to_heatmap(keypoints, rois, heatmap_size):
+    # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
+    print(f'rois:{rois.shape}')
+    print(f'heatmap_size:{heatmap_size}')
+    offset_x = rois[:, 0]
+    offset_y = rois[:, 1]
+    scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
+    scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
+
+    offset_x = offset_x[:, None]
+    offset_y = offset_y[:, None]
+    scale_x = scale_x[:, None]
+    scale_y = scale_y[:, None]
+
+    print(f'keypoints.shape:{keypoints.shape}')
+    # batch_size, num_keypoints, _ = keypoints.shape
+
+    x = keypoints[..., 0]
+    y = keypoints[..., 1]
+
+    # gs=generate_gaussian_heatmaps(x,y,512,1.0)
+    # print(f'gs_heatmap shape:{gs.shape}')
+    #
+    # show_heatmap(gs,'target')
+
+    x_boundary_inds = x == rois[:, 2][:, None]
+    y_boundary_inds = y == rois[:, 3][:, None]
+
+    x = (x - offset_x) * scale_x
+    x = x.floor().long()
+    y = (y - offset_y) * scale_y
+    y = y.floor().long()
+
+    x[x_boundary_inds] = heatmap_size - 1
+    y[y_boundary_inds] = heatmap_size - 1
+    # print(f'heatmaps x:{x}')
+    # print(f'heatmaps y:{y}')
+
+    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
+    vis = keypoints[..., 2] > 0
+    valid = (valid_loc & vis).long()
+
+    gs_heatmap=generate_gaussian_heatmaps(x,y,heatmap_size,1.0)
+
+    # show_heatmap(gs_heatmap[0],'feature')
+
+    print(f'gs_heatmap:{gs_heatmap.shape}')
+    #
+    # lin_ind = y * heatmap_size + x
+    # print(f'lin_ind:{lin_ind.shape}')
+    # heatmaps = lin_ind * valid
+
+    return gs_heatmap
+
+
+def generate_gaussian_heatmaps(xs, ys, heatmap_size, sigma=2.0, device='cuda'):
+    """
+    Generate and merge Gaussian heatmaps for a set of point pairs.
+
+    Args:
+        xs (Tensor): x coordinates of all points, shape (N, 2)
+        ys (Tensor): y coordinates of all points, shape (N, 2)
+        heatmap_size (int): heatmap size, H = W
+        sigma (float): standard deviation of the Gaussian kernel
+        device (str): device type ('cpu' or 'cuda')
+
+    Returns:
+        Tensor: merged heatmaps of shape (N, H, W)
+    """
+
+    assert xs.shape == ys.shape, "x and y must have the same shape"
+    N = xs.shape[0]
+    print(f'N:{N}')
+
+    # Create the coordinate grid
+    grid_y, grid_x = torch.meshgrid(
+        torch.arange(heatmap_size, device=device),
+        torch.arange(heatmap_size, device=device),
+        indexing='ij'
+    )
+
+    # print(f'heatmap_size:{heatmap_size}')
+    # Initialize the output heatmaps
+    combined_heatmap = torch.zeros((N,heatmap_size, heatmap_size), device=device)
+    for i in range(N):
+
+        mu_x1 = xs[i, 0].clamp(0, heatmap_size - 1).item()
+        mu_y1 = ys[i, 0].clamp(0, heatmap_size - 1).item()
+
+        # Squared distance to the first endpoint
+        dist1 = (grid_x - mu_x1) ** 2 + (grid_y - mu_y1) ** 2
+
+        # Gaussian response for the first endpoint
+        heatmap1 = torch.exp(-dist1 / (2 * sigma ** 2))
+
+        mu_x2 = xs[i, 1].clamp(0, heatmap_size - 1).item()
+        mu_y2 = ys[i, 1].clamp(0, heatmap_size - 1).item()
+
+        # Squared distance to the second endpoint
+        dist2 = (grid_x - mu_x2) ** 2 + (grid_y - mu_y2) ** 2
+
+        # Gaussian response for the second endpoint
+        heatmap2 = torch.exp(-dist2 / (2 * sigma ** 2))
+
+        heatmap=heatmap1+heatmap2
+
+        # Store the combined heatmap for this point pair
+        combined_heatmap[i]= heatmap
+
+    return combined_heatmap
+
+
+# Helper for displaying a heatmap
+def show_heatmap(heatmap, title="Heatmap"):
+    """
+    Display a heatmap with matplotlib.
+
+    Args:
+        heatmap (Tensor): heatmap tensor to display
+        title (str): plot title
+    """
+    # If the tensor is on the GPU, move it to the CPU first and convert to a numpy array
+    if heatmap.is_cuda:
+        heatmap = heatmap.cpu().numpy()
+    else:
+        heatmap = heatmap.numpy()
+
+    plt.imshow(heatmap, cmap='hot', interpolation='nearest')
+    plt.colorbar()
+    plt.title(title)
+    plt.show()
+
 def keypoints_to_heatmap(keypoints, rois, heatmap_size):
     # type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
     offset_x = rois[:, 0]
@@ -158,6 +291,7 @@ def keypoints_to_heatmap(keypoints, rois, heatmap_size):
     vis = keypoints[..., 2] > 0
     valid = (valid_loc & vis).long()
 
+
     lin_ind = y * heatmap_size + x
     heatmaps = lin_ind * valid
 
@@ -298,31 +432,50 @@ def heatmaps_to_keypoints(maps, rois):
 def lines_point_pair_loss(line_logits, proposals, gt_lines, line_matched_idxs):
     # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
     N, K, H, W = line_logits.shape
+    batch_size=len(proposals)
+    print(f'lines_point_pair_loss line_logits.shape:{line_logits.shape}')
     if H != W:
         raise ValueError(
             f"line_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
         )
     discretization_size = H
     heatmaps = []
+    gs_heatmaps=[]
     valid = []
     for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_lines, line_matched_idxs):
+        print(f'proposals_per_image:{proposals_per_image.shape}')
         kp = gt_kp_in_image[midx]
-        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
-        heatmaps.append(heatmaps_per_image.view(-1))
-        valid.append(valid_per_image.view(-1))
+        gs_heatmaps_per_img = line_points_to_heatmap(kp, proposals_per_image, discretization_size)
+        gs_heatmaps.append(gs_heatmaps_per_img)
+        # print(f'heatmaps_per_image:{heatmaps_per_image.shape}')
 
-    line_targets = torch.cat(heatmaps, dim=0)
-    valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
-    valid = torch.where(valid)[0]
+
+        # heatmaps.append(heatmaps_per_image.view(-1))
+
+        # valid.append(valid_per_image.view(-1))
+
+    # line_targets = torch.cat(heatmaps, dim=0)
+    gs_heatmaps=torch.cat(gs_heatmaps,dim=0)
+    print(f'gs_heatmaps:{gs_heatmaps.shape}, line_logits.shape:{line_logits.squeeze(1).shape}')
+    # print(f'line_targets:{line_targets.shape},{line_targets}')
+
+    # valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
+    # valid = torch.where(valid)[0]
+
+    # print(f' line_targets[valid]:{line_targets[valid]}')
 
     # torch.mean (in binary_cross_entropy_with_logits) doesn't
     # accept empty tensors, so handle it separately
-    if line_targets.numel() == 0 or len(valid) == 0:
-        return line_logits.sum() * 0
+    # if line_targets.numel() == 0 or len(valid) == 0:
+    #     return line_logits.sum() * 0
 
-    line_logits = line_logits.view(N * K, H * W)
+    # line_logits = line_logits.view(N * K, H * W)
+    # print(f'line_logits[valid]:{line_logits[valid].shape}')
+    line_logits=line_logits.squeeze(1)
+
+    # line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])
+    line_loss=F.cross_entropy(line_logits,gs_heatmaps)
 
-    line_loss = F.cross_entropy(line_logits[valid], line_targets[valid])
     return line_loss
 
 def line_inference(x, boxes):
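The hunk above replaces the integer keypoint-index targets with dense Gaussian endpoint heatmaps. A self-contained sketch of that target construction, vectorized over ROIs and endpoints rather than looping per ROI as generate_gaussian_heatmaps does, followed by a pixel-wise binary cross-entropy shown purely for illustration (the committed code instead calls F.cross_entropy on the squeezed logits); the 56x56 size and the random endpoints are assumptions:

import torch
import torch.nn.functional as F

def gaussian_endpoint_targets(endpoints, size, sigma=1.0):
    # endpoints: (N, 2, 2) tensor of (x, y) pairs already mapped into heatmap coordinates.
    ys, xs = torch.meshgrid(
        torch.arange(size, dtype=torch.float32),
        torch.arange(size, dtype=torch.float32),
        indexing="ij",
    )
    # Broadcast to (N, 2, size, size): squared distance of every pixel to each endpoint.
    dx = xs[None, None] - endpoints[:, :, 0, None, None]
    dy = ys[None, None] - endpoints[:, :, 1, None, None]
    peaks = torch.exp(-(dx ** 2 + dy ** 2) / (2 * sigma ** 2))
    # One heatmap per ROI with two Gaussian peaks (one per line endpoint).
    return peaks.sum(dim=1)

# Dummy data: 4 ROIs, each with two endpoints inside a 56x56 heatmap.
endpoints = torch.randint(0, 56, (4, 2, 2)).float()
targets = gaussian_endpoint_targets(endpoints, 56, sigma=1.0)

logits = torch.randn(4, 1, 56, 56)        # stand-in for the LinePredictor output
# Pixel-wise comparison against the soft targets; clamp because overlapping peaks can exceed 1.
loss = F.binary_cross_entropy_with_logits(logits.squeeze(1), targets.clamp(max=1.0))
print(loss.item())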
@@ -353,6 +506,7 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched
     for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
         kp = gt_kp_in_image[midx]
         heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
+
         heatmaps.append(heatmaps_per_image.view(-1))
         valid.append(valid_per_image.view(-1))
 
@@ -860,7 +1014,7 @@ class RoIHeads(nn.Module):
         if self.has_line():
             print(f'roi_heads forward has_line()!!!!')
             line_proposals = [p["boxes"] for p in result]
-            print(f'line_proposals:{len(line_proposals)}')
+            print(f'boxes_proposals:{len(line_proposals)}')
 
             # if line_proposals is None or len(line_proposals) == 0:
             #     # 返回空特征或者跳过该部分计算
@@ -892,10 +1046,15 @@ class RoIHeads(nn.Module):
                 else:
                     pos_matched_idxs = None
 
+            print(f'line_proposals:{len(line_proposals)}')
             line_features = self.line_roi_pool(features, line_proposals, image_shapes)
+            print(f'line_features from line_roi_pool:{line_features.shape}')
             line_features = self.line_head(line_features)
+            print(f'line_features from line_head:{line_features.shape}')
             line_logits = self.line_predictor(line_features)
 
+            print(f'line_logits:{line_logits.shape}')
+
             loss_line = {}
             if self.training:
                 if targets is None or pos_matched_idxs is None:

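Since this commit is about debugging the line branch, one convenient use of the show_heatmap helper added above is to compare a Gaussian target with the corresponding predicted heatmap for the same proposal. A small sketch with dummy tensors; in an actual debugging session the target would come from line_points_to_heatmap and the prediction from the line_predictor output (the names and sizes here are illustrative):

import matplotlib.pyplot as plt
import torch

target = torch.rand(56, 56)   # stand-in for gs_heatmap[i]
pred = torch.rand(56, 56)     # stand-in for line_logits[i, 0].sigmoid()

# Show target and prediction side by side with the same colormap as show_heatmap.
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, hm, title in zip(axes, (target, pred), ("target", "prediction")):
    im = ax.imshow(hm.detach().cpu().numpy(), cmap="hot", interpolation="nearest")
    ax.set_title(title)
    fig.colorbar(im, ax=ax)
plt.show()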
+ 2 - 2
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: G:\python_ws_g\data\250612
+  datadir: \\192.168.50.222/share/rlq/datasets/250612
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
 
@@ -10,7 +10,7 @@ io:
 train_params:
   resume_from:
   num_workers: 8
-  batch_size: 2
+  batch_size: 4
   max_epoch: 80000
   optim:
     name: Adam