add lines_generator

RenLiqiang 5 months ago
Parent
Commit
eb72a21f32

+ 118 - 2
libs/vision_libs/models/detection/rpn.py

@@ -33,6 +33,8 @@ class RPNHead(nn.Module):
         self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
         self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)
 
+        self.line_pred = nn.Conv2d(in_channels, num_anchors * 2, kernel_size=1, stride=1)
+
         for layer in self.modules():
             if isinstance(layer, nn.Conv2d):
                 torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
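
Note: this hunk only declares line_pred; nothing in the diff applies it, and the RPN forward below still unpacks two values from self.head(features). A minimal sketch of how it could be wired in, assuming the stock torchvision RPNHead.forward; the name line_reg and the three-tuple return are assumptions, not part of this commit:

def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
    logits = []
    bbox_reg = []
    line_reg = []
    for feature in x:
        t = self.conv(feature)              # shared conv stack, as in torchvision
        logits.append(self.cls_logits(t))   # [B, A, H, W]
        bbox_reg.append(self.bbox_pred(t))  # [B, A*4, H, W]
        line_reg.append(self.line_pred(t))  # [B, A*2, H, W] (hypothetical hookup)
    return logits, bbox_reg, line_reg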
@@ -359,6 +361,11 @@ class RegionProposalNetwork(torch.nn.Module):
         features = list(features.values())
 
         objectness, pred_bbox_deltas = self.head(features)
+        for obj in objectness:
+            print(f'objectness:{obj.shape}')
+
+        for pred_bbox in pred_bbox_deltas:
+            print(f'pred_bbox:{pred_bbox.shape}')
 
         anchors = self.anchor_generator(images, features)
 
@@ -366,13 +373,27 @@ class RegionProposalNetwork(torch.nn.Module):
         num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
         num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
         objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
+
         # apply pred_bbox_deltas to anchors to obtain the decoded proposals
         # note that we detach the deltas because Faster R-CNN do not backprop through
         # the proposals
         proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
+        print(f'box_coder.decode proposals:{proposals.shape}')
         proposals = proposals.view(num_images, -1, 4)
         boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
-        # print(f'boxes:{boxes.shape},scores:{scores.shape}')
+        print(f'boxes:{boxes[0].shape},scores:{scores[0].shape}')
+
+        lines = self.lines_generator(features, 300)
+
+        # Merge all per-image line segments into one tensor (assuming batch_size=2).
+        # Note that concatenating across the batch mixes segments from different
+        # images before the per-image box filtering below.
+        lines_all = torch.cat(lines, dim=0)  # [Total_Lines, 4]
+
+        # Keep only the segments that lie inside the proposal boxes
+        lines = self.filter_lines_inside_boxes(lines_all, boxes)
+
 
         losses = {}
         if self.training:
@@ -388,4 +409,99 @@ class RegionProposalNetwork(torch.nn.Module):
                 "loss_rpn_box_reg": loss_rpn_box_reg,
             }
         # print(f'boxes:{boxes[0].shape}')
-        return boxes, losses
+        return boxes, losses, lines
+
+    def lines_generator(self, features: List[Tensor], topk=300):
+        """
+        Args:
+            features (List[Tensor]): multi-level feature maps; only the first
+                level is used, with shape [B, C, H, W] where C >= 3
+                - features[:, 0]: jmap (junction map)
+                - features[:, 1:3]: joff (offsets in x and y)
+            topk (int): number of highest-scoring junction candidates to keep
+
+        Returns:
+            lines_batch (List[Tensor]): one [N, 4] tensor of line segments per image
+        """
+        features = features[0]  # use only the first feature level
+        B, _, H, W = features.shape
+        lines_batch = []
+
+        jmap = features[:, 0]  # shape: [B, H, W]
+        joff = features[:, 1:3]  # shape: [B, 2, H, W]
+
+        for b in range(B):
+            jmap_b = jmap[b]  # shape: [H, W]
+            joff_b = joff[b]  # shape: [2, H, W]
+
+            # Flatten and take the top-k hottest junction responses
+            val_k, idx_k = torch.topk(jmap_b.view(-1), k=topk)
+            ys = idx_k // W  # row indices
+            xs = idx_k % W  # column indices
+
+            # Fetch the predicted sub-pixel offsets
+            dx = joff_b[0, ys, xs]
+            dy = joff_b[1, ys, xs]
+
+            # Refine the coordinates with the offsets
+            points = torch.stack([
+                xs.float() + dx,
+                ys.float() + dy
+            ], dim=1)  # shape: [topk, 2]
+
+            # Pair up points to form candidate line segments
+            num_points = points.shape[0]
+            if num_points < 2:
+                lines_batch.append(torch.empty((0, 4), device=features.device))
+                continue
+
+            idx_i, idx_j = torch.triu_indices(num_points, num_points, offset=1)
+            point_i = points[idx_i]
+            point_j = points[idx_j]
+            lines = torch.cat([point_i, point_j], dim=1)  # shape: [N, 4]
+
+            lines_batch.append(lines)
+
+        print(f'lines_batch:{lines_batch[0].shape}')
+        return lines_batch
+
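
Since triu_indices pairs every top-k point with every other, topk=300 already produces 300 * 299 / 2 = 44,850 candidate segments per image before any filtering. A hedged shape check on synthetic inputs; rpn stands for a hypothetical RegionProposalNetwork instance carrying this method:

import torch

feats = [torch.rand(2, 3, 128, 128)]               # level 0: [B, C>=3, H, W]
lines_batch = rpn.lines_generator(feats, topk=300)
assert len(lines_batch) == 2                       # one tensor per image
assert lines_batch[0].shape == (44850, 4)          # C(300, 2) candidate segments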
+    def filter_lines_inside_boxes(self, lines: torch.Tensor, boxes: List[torch.Tensor]):
+        """
+        Args:
+            lines: [N, 4] line segments in [x1, y1, x2, y2] format
+            boxes: list of [K_i, 4] proposal boxes, one tensor per image
+
+        Returns:
+            filtered_lines: List[Tensor], the segments lying inside a box, per image
+        """
+        filtered_lines = []
+
+        for box in boxes:
+            # box shape: [K, 4]
+            line_masks = []
+
+            for i in range(box.shape[0]):
+                bx0, by0, bx1, by1 = box[i]
+
+                # Unpack both endpoints of every line segment
+                x1, y1, x2, y2 = lines[:, 0], lines[:, 1], lines[:, 2], lines[:, 3]
+
+                # Check whether both endpoints fall inside the box
+                in_box1 = (x1 >= bx0) & (y1 >= by0) & (x1 <= bx1) & (y1 <= by1)
+                in_box2 = (x2 >= bx0) & (y2 >= by0) & (x2 <= bx1) & (y2 <= by1)
+
+                mask = in_box1 & in_box2  # both endpoints inside the box
+                line_masks.append(mask)
+
+            if len(line_masks) == 0:
+                filtered_lines.append(torch.empty((0, 4), device=lines.device))
+            else:
+                combined_mask = torch.stack(line_masks).any(dim=0)  # keep a segment if it lies inside any box
+                filtered_line = lines[combined_mask]
+                filtered_lines.append(filtered_line)
+
+        return filtered_lines
+
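
The nested Python loop above runs one comparison pass per proposal box; the same both-endpoints-inside test can be expressed as a single broadcast. A hedged, self-contained equivalent (the function name is hypothetical):

import torch

def lines_inside_any_box(lines: torch.Tensor, box: torch.Tensor) -> torch.Tensor:
    # lines: [N, 4] as [x1, y1, x2, y2]; box: [K, 4] as [x0, y0, x1, y1]
    p1 = lines[:, None, 0:2]                        # [N, 1, 2] first endpoints
    p2 = lines[:, None, 2:4]                        # [N, 1, 2] second endpoints
    lo, hi = box[None, :, 0:2], box[None, :, 2:4]   # [1, K, 2] box corners
    in1 = ((p1 >= lo) & (p1 <= hi)).all(dim=2)      # [N, K] endpoint 1 inside box k
    in2 = ((p2 >= lo) & (p2 <= hi)).all(dim=2)      # [N, K] endpoint 2 inside box k
    return (in1 & in2).any(dim=1)                   # [N] inside at least one box

# usage: filtered = lines[lines_inside_any_box(lines, box)]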
+def non_maximum_suppression(a):
+    # A 3x3 max-pool leaves a value unchanged only where it is the maximum of
+    # its neighborhood, so multiplying by the equality mask zeroes non-maxima.
+    ap = F.max_pool2d(a, 3, stride=1, padding=1)
+    mask = (a == ap).float().clamp(min=0.0)
+    return a * mask
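
non_maximum_suppression is defined here but never called in this commit. Applied to the junction map before the top-k selection in lines_generator, it would keep one response per 3x3 neighborhood instead of letting a single junction fill several of the topk slots. A hedged hookup sketch; this call site is an assumption:

# Inside lines_generator, before the per-image loop; max_pool2d wants [B, C, H, W].
jmap = non_maximum_suppression(jmap[:, None])[:, 0]   # suppress non-peak responses
# ...then the existing torch.topk(jmap_b.view(-1), k=topk) picks distinct junctions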

+ 3 - 2
models/base/base_detection_net.py

@@ -112,14 +112,15 @@ class BaseDetectionNet(BaseModel):
 
         if isinstance(features, torch.Tensor):
             features = OrderedDict([("0", features)])
-        proposals, proposal_losses = self.rpn(images, features, targets)
+        proposals, proposal_losses, lines = self.rpn(images, features, targets)
 
 
         # print(f'proposals:{proposals[0].shape}')
 
 
 
-        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
+        detections, detector_losses = self.roi_heads(features, proposals, lines, images.image_sizes, targets)
+
         detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]
 
         # ->multi task head

+ 2 - 0
models/line_detect/roi_heads.py

@@ -1015,6 +1015,7 @@ class RoIHeads(nn.Module):
             self,
             features,  # type: Dict[str, Tensor]
             proposals,  # type: List[Tensor]
+            lines,
             image_shapes,  # type: List[Tuple[int, int]]
             targets=None,  # type: Optional[List[Dict[str, Tensor]]]
     ):
@@ -1083,6 +1084,7 @@ class RoIHeads(nn.Module):
                             "boxes": boxes[i],
                             "labels": labels[i],
                             "scores": scores[i],
+                            "lines":lines[i],
                         }
                     )
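
With this change each detection dict carries a "lines" entry next to boxes, labels, and scores. One caveat: the stock GeneralizedRCNNTransform.postprocess only rescales "boxes", "masks", and "keypoints" back to the original image size, so the new field comes back in network-input coordinates unless postprocess is extended. A hedged consumer sketch; model is a hypothetical trained detector built on BaseDetectionNet:

detections = model(images)      # eval mode: list of dicts, one per image
for det in detections:
    boxes = det["boxes"]        # [K, 4], rescaled by postprocess
    lines = det["lines"]        # [M, 4], still at network-input resolution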
 

+ 1 - 1
models/line_detect/train_demo.py

@@ -13,7 +13,7 @@ if __name__ == '__main__':
     # model=get_line_net_convnext_fpn(num_classes=2).to(device)
     # model=linenet_newresnet50fpn()
     model = linenet_newresnet18fpn()
-    model.load_best_model('train_results/20250622_140412/weights/best_val.pth')
+    # model.load_best_model('train_results/20250622_140412/weights/best_val.pth')
     # trainer = Trainer()
     # trainer.train_cfg(model,cfg='./train.yaml')
     model.start_train(cfg='train.yaml')

+ 3 - 2
models/line_detect/trainer.py

@@ -166,8 +166,9 @@ class Trainer(BaseTrainer):
         # plt.imshow(lmap)
         # plt.show()
         H = result[-1]['wires']
-        lines = H["lines"][0].cpu().numpy()
-        scores = H["score"][0].cpu().numpy()
+        # lines = H["lines"][0].cpu().numpy()
+        lines=result[0]["lines"]
+        scores =100
         for i in range(1, len(lines)):
             if (lines[i] == lines[0]).all():
                 lines = lines[:i]
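
result[0]["lines"] is a tensor (possibly on the GPU), while the duplicate-trimming loop above and typical matplotlib plotting expect host-side numpy arrays, and scores = 100 is a bare scalar. A hedged conversion sketch, assuming numpy is imported as np:

lines = result[0]["lines"].detach().cpu().numpy()   # [M, 4] host-side copy
scores = np.full(len(lines), 100.0)                 # hypothetical uniform scores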