zyhhsss 4 miesięcy temu
rodzic
commit
30cb15cb89

+ 2 - 2
models/base/transforms.py

@@ -1,6 +1,6 @@
 import logging
 import random
-from typing import Any
+from typing import Any,Tuple
 
 import cv2
 import numpy as np
@@ -455,7 +455,7 @@ class RandomPerspective:
         return img, target
 
 class DefaultTransform(nn.Module):
-    def forward(self, img: Tensor,target) -> tuple[Tensor, Any]:
+    def forward(self, img: Tensor, target) -> Tuple[Tensor, Any]:
         if not isinstance(img, Tensor):
             img = F.pil_to_tensor(img)
         return F.convert_image_dtype(img, torch.float),target

+ 104 - 4
models/line_detect/heads/head_losses.py

@@ -204,7 +204,7 @@ def line_points_to_heatmap_(keypoints, rois, heatmap_size):
 
     gs_heatmap = generate_gaussian_heatmaps(x, y, heatmap_size, 1.0)
 
-    show_heatmap(gs_heatmap[0], 'feature')
+    # show_heatmap(gs_heatmap[0], 'feature')
 
     # print(f'gs_heatmap:{gs_heatmap.shape}')
     #
@@ -546,13 +546,13 @@ def arc_points_to_heatmap(keypoints, rois, heatmap_size):
     # show_heatmap(gs[0],'target')
     all_roi_heatmap = []
     for roi, heatmap in zip(rois, gs):
-        show_heatmap(heatmap, 'target')
+        # show_heatmap(heatmap, 'target')
         print(f'heatmap:{heatmap.shape}')
         heatmap = heatmap.unsqueeze(0)
         x1, y1, x2, y2 = map(int, roi)
         roi_heatmap = torch.zeros_like(heatmap)
         roi_heatmap[..., y1:y2 + 1, x1:x2 + 1] = heatmap[..., y1:y2 + 1, x1:x2 + 1]
-        show_heatmap(roi_heatmap[0],'roi_heatmap')
+        # show_heatmap(roi_heatmap[0],'roi_heatmap')
         all_roi_heatmap.append(roi_heatmap)
 
     all_roi_heatmap = torch.cat(all_roi_heatmap)
@@ -742,10 +742,110 @@ def line_inference(x, line_boxes):
 
     boxes_per_image = [box.size(0) for box in line_boxes]
     x2 = x.split(boxes_per_image, dim=0)
-
+    # x2:tuple 2 x2[0]:[1,3,1024,1024]
+    # line_box: list:2 [1,4] [1.4] fasterrcnn kuang
     for xx, bb in zip(x2, line_boxes):
         line_prob, line_scores, = heatmaps_to_lines(xx, bb)
         lines_probs.append(line_prob)
         lines_scores.append(line_scores)
 
     return lines_probs, lines_scores
+
+def arc_inference(x, point_boxes):
+    # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+
+    points_probs = []
+    points_scores = []
+
+    boxes_per_image = [box.size(0) for box in point_boxes]
+    x2 = x.split(boxes_per_image, dim=0)
+
+    for xx, bb in zip(x2, point_boxes):
+        point_prob,point_scores = heatmaps_to_arc(xx, bb)
+
+        points_probs.append(point_prob.unsqueeze(1))
+        points_scores.append(point_scores)
+
+    return points_probs,points_scores
+
+import torch.nn.functional as F
+
+import torch.nn.functional as F
+
+def heatmaps_to_arc(maps, rois, threshold=0.1, output_size=(128, 128)):
+    """
+    Args:
+        maps: [N, 3, H, W] - full heatmaps
+        rois: [N, 4] - bounding boxes
+        threshold: float - binarization threshold
+        output_size: resized size for uniform NMS
+
+    Returns:
+        masks: [N, 1, H, W] - binary mask aligned with input map
+        scores: [N, 1] - count of non-zero pixels in each mask
+    """
+    N, _, H, W = maps.shape
+    masks = torch.zeros((N, 1, H, W), dtype=torch.float32, device=maps.device)
+    scores = torch.zeros((N, 1), dtype=torch.float32, device=maps.device)
+
+    point_maps = maps[:, 0]  # È¡µÚÒ»¸öͨµÀ [N, H, W]
+
+    print(f"==> heatmaps_to_arc: maps.shape = {maps.shape}, rois.shape = {rois.shape}")
+
+    for i in range(N):
+        x1, y1, x2, y2 = rois[i].long()
+        x1 = x1.clamp(0, W - 1)
+        x2 = x2.clamp(0, W - 1)
+        y1 = y1.clamp(0, H - 1)
+        y2 = y2.clamp(0, H - 1)
+
+        print(f"[{i}] roi: ({x1.item()}, {y1.item()}, {x2.item()}, {y2.item()})")
+
+        if x2 <= x1 or y2 <= y1:
+            print(f"    Skipped invalid ROI at index {i}")
+            continue
+
+        roi_map = point_maps[i, y1:y2, x1:x2]  # [h, w]
+        print(f"    roi_map.shape: {roi_map.shape}")
+
+        if roi_map.numel() == 0:
+            print(f"    Skipped empty ROI at index {i}")
+            continue
+
+        # resize to uniform size
+        roi_map_resized = F.interpolate(
+            roi_map.unsqueeze(0).unsqueeze(0),
+            size=output_size,
+            mode='bilinear',
+            align_corners=False
+        )  # [1, 1, H, W]
+        print(f"    roi_map_resized.shape: {roi_map_resized.shape}")
+
+        # NMS + threshold
+        nms_roi = non_maximum_suppression(roi_map_resized)  # shape: [1, H, W]
+        bin_mask = (nms_roi > threshold).float()  # shape: [1, H, W]
+        print(f"    bin_mask.sum(): {bin_mask.sum().item()}")
+
+        # resize back to original roi size
+        h = int((y2 - y1).item())
+        w = int((x2 - x1).item())
+        # È·±£ bin_mask ÊÇ [1, 128, 128]
+        assert bin_mask.dim() == 4, f"Expected 3D tensor [1, H, W], got {bin_mask.shape}"
+
+        # ÉϲÉÑù»Ø ROI ԭʼ´óС
+        bin_mask_original_size = F.interpolate(
+            # bin_mask.unsqueeze(0),  # ? [1, 1, 128, 128]
+            bin_mask,  # ? [1, 1, 128, 128]
+            size=(h, w),
+            mode='bilinear',
+            align_corners=False
+        )[0]  # ? [1, h, w]
+
+        masks[i, 0, y1:y2, x1:x2] = bin_mask_original_size.squeeze()
+        scores[i] = bin_mask_original_size.sum()
+
+        print(f"    bin_mask_original_size.shape: {bin_mask_original_size.shape}, sum: {scores[i].item()}")
+
+    print(f"==> Done. Total valid masks: {(scores > 0).sum().item()} / {N}")
+
+    return masks, scores

+ 3 - 2
models/line_detect/line_dataset.py

@@ -96,6 +96,7 @@ class LineDataset(BaseDataset):
         target["image_id"] = torch.tensor(item)
         boxes, lines, points, arc_mask,labels = get_boxes_lines(objs, shape)
 
+
         if points is not None:
             target["points"]=points
         if lines is not None:
@@ -153,7 +154,7 @@ class LineDataset(BaseDataset):
         #     plt.show()
 
         if show_type=='points':
-            print(f'points:{target['points'].shape}')
+            # print(f'points:{target['points'].shape}')
             keypoint_img=draw_keypoints((img * 255).to(torch.uint8),target['points'].unsqueeze(1),colors='red',width=3)
             plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
             plt.show()
@@ -211,7 +212,7 @@ def get_boxes_lines(objs,shape):
 
 
 
-        elif label == 'arc':
+        elif label == 'arc' :
 
             line_mask.append(obj['points'])
 

+ 5 - 5
models/line_detect/line_detect.py

@@ -104,8 +104,8 @@ class LineDetect(BaseDetectionNet):
             arc_predictor=None,
             num_points=3,
             detect_point=False,
-            detect_line=True,
-            detect_arc=False,
+            detect_line=False,
+            detect_arc=True,
             **kwargs,
 
     ):
@@ -194,7 +194,7 @@ class LineDetect(BaseDetectionNet):
             arc_head=ArcHeads(8,layers)
         if detect_arc and arc_predictor is None:
             layers = tuple(num_points for _ in range(8))
-            arc_predictor=ArcPredictor(in_channels=128)
+            arc_predictor=ArcPredictor(in_channels=256)
 
 
 
@@ -368,7 +368,7 @@ def linedetect_newresnet50fpn(
     # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
     # weights_backbone = ResNet50_Weights.verify(weights_backbone)
     if num_classes is None:
-        num_classes = 3
+        num_classes = 4
     if num_points is None:
         num_points = 3
 
@@ -447,7 +447,7 @@ def linedetect_newresnet152fpn(
     # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
     # weights_backbone = ResNet50_Weights.verify(weights_backbone)
     if num_classes is None:
-        num_classes = 3
+        num_classes = 4
     if num_points is None:
         num_points = 3
 

+ 19 - 19
models/line_detect/loi_heads.py

@@ -13,7 +13,7 @@ import libs.vision_libs.models.detection._utils as det_utils
 from collections import OrderedDict
 
 from models.line_detect.heads.head_losses import point_inference, compute_point_loss, line_iou_loss, \
-    lines_point_pair_loss, features_align, line_inference, compute_arc_loss
+    lines_point_pair_loss, features_align, line_inference, compute_arc_loss, arc_inference
 
 
 def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
@@ -954,7 +954,7 @@ class RoIHeads(nn.Module):
                 gt_lines = [t["lines"] for t in targets if "lines" in t]
 
 
-                print(f'gt_lines:{gt_lines[0].shape}')
+                # print(f'gt_lines:{gt_lines[0].shape}')
                 h, w = targets[0]["img_size"]
                 img_size = h
 
@@ -1161,6 +1161,7 @@ class RoIHeads(nn.Module):
             # print(f'labels:{labels}')
             arc_proposals = [p["boxes"] for p in result]
             print(f'boxes_proposals:{len(arc_proposals)}')
+            print(f'boxes_proposals:{len(arc_proposals)}')
 
             # if line_proposals is None or len(line_proposals) == 0:
             #     # 返回空特征或者跳过该部分计算
@@ -1195,7 +1196,7 @@ class RoIHeads(nn.Module):
                         arc_pos_matched_idxs.append(matched_idxs[img_id][arc_pos])
 
                 else:
-                    pos_matched_idxs = None
+                    arc_pos_matched_idxs = None
 
             feature_logits = self.arc_forward1(features, image_shapes, arc_proposals)
 
@@ -1212,15 +1213,15 @@ class RoIHeads(nn.Module):
                 h, w = targets[0]["img_size"]
                 img_size = h
 
-                gt_arcs_tensor = torch.zeros(0, 0)
+                # gt_arcs_tensor = torch.zeros(0, 0)
                 # if len(gt_arcs) > 0:
-                    # gt_arcs_tensor = torch.cat(gt_arcs)
-                    # print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
-
+                #     gt_arcs_tensor = torch.cat(gt_arcs)
+                #     print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
+                #
                 # if gt_arcs_tensor.shape[0] > 0:
                 #     print(f'start to compute point_loss')
-
-                loss_arc=compute_arc_loss(feature_logits,arc_proposals,gt_arcs,arc_pos_matched_idxs)
+                if len(gt_arcs) > 0 and feature_logits is not None:
+                    loss_arc = compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
 
                 if loss_arc is None:
                     print(f'loss_arc is None111')
@@ -1243,9 +1244,8 @@ class RoIHeads(nn.Module):
                         gt_arcs_tensor = torch.cat(gt_arcs)
                         print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
 
-                    if gt_arcs_tensor.shape[0] > 0:
-                        print(f'start to compute point_loss')
-
+                    if gt_arcs_tensor.shape[0] > 0 and feature_logits is not None:
+                        print(f'start to compute arc_loss')
                         loss_arc = compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
 
                     if loss_arc is None:
@@ -1270,8 +1270,8 @@ class RoIHeads(nn.Module):
                             r["arcs"] = keypoint_prob
                             r["arcs_scores"] = kps
 
-            print(f'loss_point:{loss_point}')
-            losses.update(loss_point)
+            # print(f'loss_point:{loss_point}')
+            losses.update(loss_arc)
             print(f'losses:{losses}')
 
 
@@ -1369,7 +1369,7 @@ class RoIHeads(nn.Module):
     def line_forward1(self, features, image_shapes, line_proposals):
         print(f'line_proposals:{len(line_proposals)}')
         # cs_features= features['0']
-        print(f'features-0:{features['0'].shape}')
+        # print(f'features-0:{features['0'].shape}')
         cs_features = self.channel_compress(features['0'])
         filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
         if len(filtered_proposals) > 0:
@@ -1387,7 +1387,7 @@ class RoIHeads(nn.Module):
     def line_forward2(self, features, image_shapes, line_proposals):
         print(f'line_proposals:{len(line_proposals)}')
         # cs_features= features['0']
-        print(f'features-0:{features['0'].shape}')
+        # print(f'features-0:{features['0'].shape}')
         # cs_features = self.channel_compress(features['0'])
         cs_features=features['0']
         filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
@@ -1410,7 +1410,7 @@ class RoIHeads(nn.Module):
     def line_forward3(self, features, image_shapes, line_proposals):
         print(f'line_proposals:{len(line_proposals)}')
         # cs_features= features['0']
-        print(f'features-0:{features['0'].shape}')
+        # print(f'features-0:{features['0'].shape}')
         # cs_features = self.channel_compress(features['0'])
         cs_features=features['0']
         # filtered_proposals = [proposal for proposal in line_proposals if proposal.shape[0] > 0]
@@ -1433,7 +1433,7 @@ class RoIHeads(nn.Module):
     def point_forward1(self, features, image_shapes, proposals):
         print(f'point_proposals:{len(proposals)}')
         # cs_features= features['0']
-        print(f'features-0:{features['0'].shape}')
+        # print(f'features-0:{features['0'].shape}')
         # cs_features = self.channel_compress(features['0'])
         cs_features=features['0']
         # filtered_proposals = [proposal for proposal in proposals if proposal.shape[0] > 0]
@@ -1457,7 +1457,7 @@ class RoIHeads(nn.Module):
     def arc_forward1(self, features, image_shapes, proposals):
         print(f'point_proposals:{len(proposals)}')
         # cs_features= features['0']
-        print(f'features-0:{features['0'].shape}')
+        # print(f'features-0:{features['0'].shape}')
         # cs_features = self.channel_compress(features['0'])
         cs_features=features['0']
         # filtered_proposals = [proposal for proposal in proposals if proposal.shape[0] > 0]

+ 3 - 3
models/line_detect/train.yaml

@@ -1,7 +1,7 @@
 io:
   logdir: train_results
 
-  datadir: /data/share/rlq/datasets/250718caisegangban
+  datadir: /data/zyh/py_ws/code/a_dataset
 #  datadir: /data/share/zjh/Dataset_correct_xanylabel_tiff
 
 #  datadir: /data/share/rlq/datasets/250718caisegangban
@@ -22,8 +22,8 @@ train_params:
   num_workers: 8
   batch_size: 2
   max_epoch: 8000000
-  augmentation: True
-#  augmentation: False
+#  augmentation: True
+  augmentation: False
   optim:
     name: Adam
     lr: 4.0e-4