فهرست منبع

修复整合point和line导致报错的bug,分离point、line推理相关功能

lstrlq 5 ماه پیش
والد
کامیت
832e737f65
2فایلهای تغییر یافته به همراه74 افزوده شده و 56 حذف شده
  1. 73 55
      models/line_detect/loi_heads.py
  2. 1 1
      models/line_detect/train.yaml

+ 73 - 55
models/line_detect/loi_heads.py

@@ -569,28 +569,43 @@ def non_maximum_suppression(a):
     mask = (a == ap).float().clamp(min=0.0)
     return a * mask
 
+def heatmaps_to_points(maps, rois):
 
-def heatmaps_to_lines(maps, rois):
-    """Extract predicted keypoint locations from heatmaps. Output has shape
-    (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob)
-    for each keypoint.
-    """
-    # This function converts a discrete image coordinate in a HEATMAP_SIZE x
-    # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain
-    # consistency with keypoints_to_heatmap_labels by using the conversion from
-    # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a
-    # continuous coordinate.
-    line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
-    line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
 
     point_preds = torch.zeros((len(rois),  2), dtype=torch.float32, device=maps.device)
     point_end_scores = torch.zeros((len(rois), 1), dtype=torch.float32, device=maps.device)
 
     print(f'heatmaps_to_lines:{maps.shape}')
     point_maps=maps[:,0]
+    print(f'point_map:{point_maps.shape}')
+    for i in range(len(rois)):
+
+        point_roi_map = point_maps[i].unsqueeze(0)
+        print(f'point_roi_map:{point_roi_map.shape}')
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = point_roi_map.shape[2]
+        flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
+        point_score, point_index = torch.topk(flatten_point_roi_map, k=1)
+        print(f'point index:{point_index}')
+        # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+        point_x =point_index % w
+        point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
+        point_preds[i, 0,] = point_x
+        point_preds[i, 1,] = point_y
+
+        point_end_scores[i, :] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
+
+
+    return point_preds,point_end_scores
+
+def heatmaps_to_lines(maps, rois):
+    line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
+    line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
+
     line_maps=maps[:,1]
 
-    print(f'point_map:{point_maps.shape}')
+
     for i in range(len(rois)):
         line_roi_map = line_maps[i].unsqueeze(0)
 
@@ -609,28 +624,13 @@ def heatmaps_to_lines(maps, rois):
         line_preds[i, 2, :] = 1
         line_end_scores[i, :] = line_roi_map[torch.arange(1, device=line_roi_map.device), line_y, line_x]
 
-        point_roi_map = point_maps[i].unsqueeze(0)
 
-        print(f'point_roi_map:{point_roi_map.shape}')
-        # roi_map_probs = scores_to_probs(roi_map.copy())
-        w = point_roi_map.shape[2]
-        flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
-        point_score, point_index = torch.topk(flatten_point_roi_map, k=1)
-        print(f'point index:{point_index}')
-        # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
-
-        point_x =point_index % w
-        point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
-        point_preds[i, 0,] = point_x
-        point_preds[i, 1,] = point_y
-
-        point_end_scores[i, :] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
 
 
-    return line_preds.permute(0, 2, 1), line_end_scores,point_preds,point_end_scores
+    return line_preds.permute(0, 2, 1), line_end_scores
 
 
-def lines_features_align(features, proposals, img_size):
+def features_align(features, proposals, img_size):
     print(f'lines_features_align features:{features.shape},proposals:{len(proposals)}')
 
     align_feat_list = []
@@ -866,28 +866,37 @@ def line_iou_loss(x, boxes, gt_lines, matched_idx, img_size=511, alpha=1.0, beta
     return torch.mean(torch.cat(losses))
 
 
-
-
-def line_inference(x, boxes):
+def point_inference(x, point_boxes):
     # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
-    lines_probs = []
-    lines_scores = []
 
     points_probs = []
     points_scores = []
 
-    boxes_per_image = [box.size(0) for box in boxes]
+    boxes_per_image = [box.size(0) for box in point_boxes]
     x2 = x.split(boxes_per_image, dim=0)
 
-    for xx, bb in zip(x2, boxes):
-        line_prob, line_scores,point_prob,point_scores = heatmaps_to_lines(xx, bb)
-        lines_probs.append(line_prob)
-        lines_scores.append(line_scores)
+    for xx, bb in zip(x2, point_boxes):
+        point_prob,point_scores = heatmaps_to_points(xx, bb)
 
         points_probs.append(point_prob.unsqueeze(1))
         points_scores.append(point_scores)
 
-    return lines_probs, lines_scores,points_probs,points_scores
+    return points_probs,points_scores
+
+def line_inference(x, line_boxes):
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+    lines_probs = []
+    lines_scores = []
+
+    boxes_per_image = [box.size(0) for box in line_boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+
+    for xx, bb in zip(x2, line_boxes):
+        line_prob, line_scores, = heatmaps_to_lines(xx, bb)
+        lines_probs.append(line_prob)
+        lines_scores.append(line_scores)
+
+    return lines_probs, lines_scores
 
 
 def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
@@ -1384,6 +1393,8 @@ class RoIHeads(nn.Module):
 
         box_features = self.box_roi_pool(features, proposals, image_shapes)
         box_features = self.box_head(box_features)
+
+
         class_logits, box_regression = self.box_predictor(box_features)
 
         result: List[Dict[str, torch.Tensor]] = []
@@ -1519,9 +1530,10 @@ class RoIHeads(nn.Module):
 
             # line_features = lines_features_align(line_features, filtered_proposals, image_shapes)
 
-            point_features = lines_features_align(cs_features, point_proposals, image_shapes)
+            point_features =features_align(cs_features, point_proposals, image_shapes)
+
 
-            line_features = lines_features_align(cs_features, line_proposals, image_shapes)
+            line_features = features_align(cs_features, line_proposals, image_shapes)
 
 
 
@@ -1593,16 +1605,16 @@ class RoIHeads(nn.Module):
 
 
 
-                    if gt_lines_tensor.shape[0] > 0:
+                    if gt_lines_tensor.shape[0] > 0 and line_features is not None:
                         loss_line = lines_point_pair_loss(
-                            line_logits, line_proposals, gt_lines, line_pos_matched_idxs
+                            line_features, line_proposals, gt_lines, line_pos_matched_idxs
                         )
-                        loss_line_iou = line_iou_loss(line_logits, line_proposals, gt_lines, line_pos_matched_idxs,
+                        loss_line_iou = line_iou_loss(line_features, line_proposals, gt_lines, line_pos_matched_idxs,
                                                       img_size)
 
-                    if gt_points_tensor.shape[0] > 0:
+                    if gt_points_tensor.shape[0] > 0 and point_features is not None:
                         loss_point = compute_point_loss(
-                            line_logits, point_proposals, gt_points, point_pos_matched_idxs
+                            point_features, point_proposals, gt_points, point_pos_matched_idxs
                         )
 
                     if not loss_line :
@@ -1623,14 +1635,20 @@ class RoIHeads(nn.Module):
                             "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                         )
 
-                    lines_probs, lines_scores,point_probs,points_scores = line_inference(line_logits, line_proposals)
+                    if line_features is not None:
+                        lines_probs, lines_scores = line_inference(line_features,line_proposals)
+                        for keypoint_prob, kps, r in zip(lines_probs, lines_scores, result):
+                            r["lines"] = keypoint_prob
+                            r["liness_scores"] = kps
+                    if point_features is not None:
+                        point_probs, points_scores=point_inference(point_features, point_proposals,)
+                        for  points, ps, r in zip(point_probs,points_scores, result):
+                            print(f'points_prob :{points.shape}')
+
+                            r["points"] = points
+                            r["points_scores"] = ps
+
 
-                    for keypoint_prob, kps, points,ps,r in zip(lines_probs, lines_scores,point_probs,points_scores, result):
-                        print(f'points_prob :{points.shape}')
-                        r["lines"] = keypoint_prob
-                        r["liness_scores"] = kps
-                        r["points"] = points
-                        r["points_scores"] = ps
 
             losses.update(loss_line)
             losses.update(loss_line_iou)

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: /data/share/zjh/Dataset0709_2
+  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
   data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000