
add circle_predictor

RenLiqiang 4 months ago
commit 49ea21160e

+ 5 - 0
libs/vision_libs/models/detection/transform.py

@@ -222,6 +222,11 @@ class GeneralizedRCNNTransform(nn.Module):
             arc_mask = target["arc_mask"]
             arc_mask = resize_keypoints(arc_mask, (h, w), image.shape[-2:])
             target["arc_mask"] = arc_mask
+
+        if "circles" in target:
+            arc_mask = target["circles"]
+            arc_mask = resize_keypoints(arc_mask, (h, w), image.shape[-2:])
+            target["circles"] = arc_mask
         return image, target
 
     # _onnx_batch_images() is an implementation of

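Note: the new "circles" branch mirrors the existing arc_mask handling, so the 4-point circle annotations are rescaled whenever GeneralizedRCNNTransform resizes the image. A minimal sketch of the coordinate rescaling, assuming resize_keypoints follows the usual torchvision contract (x scaled by the width ratio, y by the height ratio); the helper name resize_keypoints_sketch is hypothetical:

    import torch

    def resize_keypoints_sketch(keypoints, original_size, new_size):
        # (..., 2) tensor of (x, y) points; scale by new/original size ratios
        ratio_h = float(new_size[0]) / float(original_size[0])
        ratio_w = float(new_size[1]) / float(original_size[1])
        resized = keypoints.clone().float()
        resized[..., 0] *= ratio_w  # x follows image width
        resized[..., 1] *= ratio_h  # y follows image height
        return resized

    # e.g. two circles, 4 points each, resized from 1024x1024 to 768x768
    circles = torch.rand(2, 4, 2) * 1024
    circles_resized = resize_keypoints_sketch(circles, (1024, 1024), (768, 768))
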
+ 5 - 5
models/line_detect/heads/head_losses.py

@@ -28,12 +28,12 @@ def combined_loss(preds, targets, alpha=0.5):
     return alpha * bce + (1 - alpha) * d
 
 def features_align(features, proposals, img_size):
-    print(f'lines_features_align features:{features.shape},proposals:{len(proposals)}')
+    print(f'features_align features:{features.shape},proposals:{len(proposals)}')
 
     align_feat_list = []
 
     for feat, proposals_per_img in zip(features, proposals):
-        print(f'lines_features_align feat:{feat.shape}, proposals_per_img:{proposals_per_img.shape}')
+        print(f'feature_align feat:{feat.shape}, proposals_per_img:{proposals_per_img.shape}')
         if proposals_per_img.shape[0]>0:
             feat = feat.unsqueeze(0)
             for proposal in proposals_per_img:
@@ -156,7 +156,7 @@ def points_to_heatmap(keypoints, rois,num_points=2, heatmap_size=(512,512)):
     y = keypoints[..., 1].unsqueeze(1)
 
 
-    gs = generate_gaussian_heatmaps(x, y,num_points=num_points, heatmap_size=heatmap_size, sigma=1.0)
+    gs = generate_gaussian_heatmaps(x, y,num_points=num_points, heatmap_size=heatmap_size, sigma=2.0)
     # show_heatmap(gs[0],'target')
     all_roi_heatmap = []
     for roi, heatmap in zip(rois, gs):
@@ -287,8 +287,8 @@ def generate_gaussian_heatmaps(xs, ys, heatmap_size,num_points=2, sigma=2.0, dev
 
     assert xs.shape == ys.shape, "x and y must have the same shape"
     print(f'xs:{xs.shape}')
-    # xs=xs.squeeze(1)
-    # ys = ys.squeeze(1)
+    xs=xs.squeeze(1)
+    ys = ys.squeeze(1)
     print(f'xs1:{xs.shape}')
     N = xs.shape[0]
     print(f'N:{N},num_points:{num_points}')

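Note: besides the print-label fixes, this hunk raises sigma from 1.0 to 2.0 in points_to_heatmap and enables the squeeze(1) calls so xs/ys enter generate_gaussian_heatmaps with the expected rank. A rough illustration of what the sigma change does, using a hypothetical single-point stand-in for the heatmap generator (gaussian_heatmap_sketch is not code from the repo):

    import torch

    def gaussian_heatmap_sketch(x, y, heatmap_size=(512, 512), sigma=2.0):
        # one 2D Gaussian centred at (x, y); a larger sigma spreads the peak
        # over more pixels, giving the BCE/Dice target a wider support
        h, w = heatmap_size
        ys = torch.arange(h, dtype=torch.float32).view(-1, 1)
        xs = torch.arange(w, dtype=torch.float32).view(1, -1)
        return torch.exp(-((xs - x) ** 2 + (ys - y) ** 2) / (2 * sigma ** 2))

    narrow = gaussian_heatmap_sketch(256.0, 256.0, sigma=1.0)
    wide = gaussian_heatmap_sketch(256.0, 256.0, sigma=2.0)
    print((narrow > 0.5).sum().item(), (wide > 0.5).sum().item())  # the sigma=2.0 peak covers more pixels
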
+ 3 - 3
models/line_detect/line_dataset.py

@@ -112,7 +112,7 @@ class LineDataset(BaseDataset):
         #     print(f'not arc_mask dataset')
 
         if circle_4points is not None:
-            target['circle']=circle_4points
+            target['circles']=circle_4points
 
         target["boxes"]=boxes
         target["labels"]=labels
@@ -248,7 +248,7 @@ def get_boxes_lines(objs,shape):
 
             boxes.append([xmin, ymin, xmax, ymax])
 
-            labels.append(torch.tensor(3))
+            labels.append(torch.tensor(4))
 
     boxes=torch.tensor(boxes)
     print(f'boxes:{boxes.shape}')
@@ -283,4 +283,4 @@ def get_boxes_lines(objs,shape):
 if __name__ == '__main__':
     path=r"\\192.168.50.222/share/zyh/data/rgb_4point/a_dataset"
     dataset= LineDataset(dataset_path=path, dataset_type='train',augmentation=False, data_type='jpg')
-    dataset.show(1,show_type='all')
+    dataset.show(99,show_type='all')

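Note: the dataset now stores the annotation under target['circles'] (matching the key that GeneralizedRCNNTransform and RoIHeads read) and assigns label 4 to the circle class; the __main__ block simply previews sample 99 instead of sample 1. A hypothetical target dict in the shape this commit expects downstream (keys and class id taken from the commit, coordinate values made up):

    import torch

    target = {
        "boxes": torch.tensor([[10.0, 10.0, 110.0, 110.0]]),  # (N, 4) xyxy
        "labels": torch.tensor([4]),                          # 4 = circle class
        "circles": torch.tensor([[[60.0, 10.0],               # (N, 4, 2): 4 points per circle
                                  [110.0, 60.0],
                                  [60.0, 110.0],
                                  [10.0, 60.0]]]),
        "img_size": (768, 768),                               # as read in RoIHeads
    }
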
+ 3 - 3
models/line_detect/line_detect.py

@@ -108,7 +108,7 @@ class LineDetect(BaseDetectionNet):
             arc_roi_pool=None,
             arc_head=None,
             arc_predictor=None,
-            num_points=3,
+            num_points=4,
             detect_point=False,
             detect_line=False,
             detect_arc=True,
@@ -396,9 +396,9 @@ def linedetect_newresnet50fpn(
     # weights = LineNet_ResNet50_FPN_Weights.verify(weights)
     # weights_backbone = ResNet50_Weights.verify(weights_backbone)
     if num_classes is None:
-        num_classes = 4
+        num_classes = 5
     if num_points is None:
-        num_points = 3
+        num_points = 4
 
     size=768
     backbone =resnet50fpn(out_channels=256)

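Note: the factory defaults move in lockstep, num_classes 4 -> 5 (background plus the object classes, making room for the circle label 4) and num_points 3 -> 4 for the 4-point circle annotation. Callers that relied on the old defaults need matching arguments; a sketch of the updated call, assuming the import path used by train_demo.py:

    from models.line_detect.line_detect import linedetect_newresnet50fpn

    # explicit arguments shown for clarity -- they now match the new defaults
    model = linedetect_newresnet50fpn(num_classes=5, num_points=4)
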
+ 8 - 8
models/line_detect/loi_heads.py

@@ -1323,7 +1323,7 @@ class RoIHeads(nn.Module):
                 if matched_idxs is None:
                     raise ValueError("if in trainning, matched_idxs should not be None")
                 for img_id in range(num_images):
-                    circle_pos = torch.where(labels[img_id] == 1)[0]
+                    circle_pos = torch.where(labels[img_id] == 4)[0]
                     circle_proposals.append(proposals[img_id][circle_pos])
                     circle_pos_matched_idxs.append(matched_idxs[img_id][circle_pos])
             else:
@@ -1338,7 +1338,7 @@ class RoIHeads(nn.Module):
                         raise ValueError("if in trainning, matched_idxs should not be None")
 
                     for img_id in range(num_images):
-                        circle_pos = torch.where(labels[img_id] == 1)[0]
+                        circle_pos = torch.where(labels[img_id] == 4)[0]
                         circle_proposals.append(proposals[img_id][circle_pos])
                         circle_pos_matched_idxs.append(matched_idxs[img_id][circle_pos])
 
@@ -1354,7 +1354,7 @@ class RoIHeads(nn.Module):
                 if targets is None or circle_pos_matched_idxs is None:
                     raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")
 
-                gt_circles = [t["circle"] for t in targets if "circle" in t]
+                gt_circles = [t["circles"] for t in targets if "circles" in t]
 
                 print(f'gt_circle:{gt_circles[0].shape}')
                 h, w = targets[0]["img_size"]
@@ -1374,13 +1374,13 @@ class RoIHeads(nn.Module):
                     print(f'loss_circle is None111')
                     loss_circle = torch.tensor(0.0, device=device)
 
-                loss_point = {"loss_circle": loss_circle}
+                loss_circle = {"loss_circle": loss_circle}
 
             else:
                 if targets is not None:
                     h, w = targets[0]["img_size"]
                     img_size = h
-                    gt_circles = [t["circle"] for t in targets if "circle" in t]
+                    gt_circles = [t["circles"] for t in targets if "circles" in t]
 
                     gt_circles_tensor = torch.zeros(0, 0)
                     if len(gt_circles) > 0:
@@ -1390,7 +1390,7 @@ class RoIHeads(nn.Module):
                     if gt_circles_tensor.shape[0] > 0:
                         print(f'start to compute circle_loss')
 
-                        loss_circle = compute_circle_loss(feature_logits, point_proposals, gt_circles,
+                        loss_circle = compute_circle_loss(feature_logits, circle_proposals, gt_circles,
                                                         circle_pos_matched_idxs)
 
                     if loss_circle is None:
@@ -1599,7 +1599,7 @@ class RoIHeads(nn.Module):
         return roi_features
 
     def circle_forward1(self, features, image_shapes, proposals):
-        print(f'point_proposals:{len(proposals)}')
+        print(f'circle_proposals:{len(proposals)}')
         # cs_features= features['0']
         # print(f'features-0:{features['0'].shape}')
         # cs_features = self.channel_compress(features['0'])
@@ -1614,7 +1614,7 @@ class RoIHeads(nn.Module):
         # print(f'point_proposals_tensor:{point_proposals_tensor.shape}')
 
         feature_logits = self.circle_predictor(cs_features)
-        print(f'feature_logits from line_head:{feature_logits.shape}')
+        print(f'feature_logits from circle_head:{feature_logits.shape}')
 
         roi_features = features_align(feature_logits, proposals, image_shapes)
         if roi_features is not None:

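Note: the RoIHeads changes select circle proposals by the new class id (labels == 4 instead of == 1), read ground truth from t["circles"], fix the returned loss dict key, and pass circle_proposals (rather than point_proposals) into compute_circle_loss. The per-image selection boils down to the pattern below (a standalone paraphrase, not code from the repo):

    import torch

    CIRCLE_LABEL = 4  # class id used for circles after this commit

    def select_circle_proposals(proposals, labels, matched_idxs):
        # keep, per image, only the proposals whose matched label is the circle class
        circle_proposals, circle_pos_matched_idxs = [], []
        for props, lbls, midx in zip(proposals, labels, matched_idxs):
            pos = torch.where(lbls == CIRCLE_LABEL)[0]
            circle_proposals.append(props[pos])
            circle_pos_matched_idxs.append(midx[pos])
        return circle_proposals, circle_pos_matched_idxs
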
+ 1 - 1
models/line_detect/train_demo.py

@@ -18,7 +18,7 @@ if __name__ == '__main__':
 
     # model=linedetect_resnet18_fpn()
     # model=linedetect_newresnet18fpn(num_points=3)
-    model=linedetect_newresnet50fpn(num_points=3)
+    model=linedetect_newresnet50fpn(num_points=4)
     # model = linedetect_newresnet101fpn(num_points=3)
     # model = linedetect_newresnet152fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')