Преглед на файлове

debug infer circle on 4080

admin преди 4 месеца
родител
ревизия
d353064628

+ 5 - 0
libs/vision_libs/models/detection/transform.py

@@ -309,6 +309,11 @@ class GeneralizedRCNNTransform(nn.Module):
                 keypoints = pred["points"]
                 keypoints = resize_keypoints(keypoints, im_s, o_im_s)
                 result[i]["points"] = keypoints
+
+            if "circles" in pred:
+                keypoints = pred["circles"]
+                keypoints = resize_keypoints(keypoints, im_s, o_im_s)
+                result[i]["circles"] = keypoints
         return result
 
     def __repr__(self) -> str:

+ 33 - 1
models/line_detect/heads/head_losses.py

@@ -438,6 +438,38 @@ def heatmaps_to_points(maps, rois,num_points=2):
 
     return point_preds,point_end_scores
 
+def heatmaps_to_circle_points(maps, rois,num_points=2):
+
+
+    point_preds = torch.zeros((len(rois), 4, 2), dtype=torch.float32, device=maps.device)
+    point_end_scores = torch.zeros((len(rois),4, 1), dtype=torch.float32, device=maps.device)
+
+    print(f'heatmaps_to_lines:{maps.shape}')
+    point_maps=maps[:,0]
+    print(f'point_map:{point_maps.shape}')
+    for i in range(len(rois)):
+
+        point_roi_map = point_maps[i].unsqueeze(0)
+        print(f'point_roi_map:{point_roi_map.shape}')
+        # roi_map_probs = scores_to_probs(roi_map.copy())
+        w = point_roi_map.shape[2]
+        flatten_point_roi_map = non_maximum_suppression(point_roi_map).reshape(1, -1)
+        point_score, point_index = torch.topk(flatten_point_roi_map, k=num_points)
+        print(f'point index:{point_index}')
+        # pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
+
+        point_x =point_index % w
+        point_y = torch.div(point_index - point_x, w, rounding_mode="floor")
+
+
+        point_preds[i, :,0] = point_x
+        point_preds[i, :,1] = point_y
+
+        point_end_scores[i, :,0] = point_roi_map[torch.arange(1, device=point_roi_map.device), point_y, point_x]
+
+
+    return point_preds,point_end_scores
+
 def heatmaps_to_lines(maps, rois):
     line_preds = torch.zeros((len(rois), 3, 2), dtype=torch.float32, device=maps.device)
     line_end_scores = torch.zeros((len(rois), 2), dtype=torch.float32, device=maps.device)
@@ -860,7 +892,7 @@ def circle_inference(x, point_boxes):
     x2 = x.split(boxes_per_image, dim=0)
 
     for xx, bb in zip(x2, point_boxes):
-        point_prob,point_scores = heatmaps_to_points(xx, bb,num_points=4)
+        point_prob,point_scores = heatmaps_to_circle_points(xx, bb,num_points=4)
 
         points_probs.append(point_prob.unsqueeze(1))
         points_scores.append(point_scores)

+ 4 - 0
models/line_detect/line_dataset.py

@@ -236,6 +236,7 @@ def get_boxes_lines(objs,shape):
             labels.append(torch.tensor(3))
 
         elif label == 'circle' :
+            # print(f'len circle_4points: {len(obj['points'])}')
             circle_4points.append(obj['points'])
 
             xmin = max(obj['xmin'] - 6, 0)
@@ -276,7 +277,10 @@ def get_boxes_lines(objs,shape):
     if len(circle_4points)==0:
         circle_4points=None
     else:
+        # for circle_4point in circle_4points:
+            # print(f'circle_4point len111:{len(circle_4point)}')
         circle_4points=torch.tensor(circle_4points,dtype=torch.float32)
+        # print(f'circle_4points shape:{circle_4points.shape}')
 
     return boxes,line_point_pairs,points,line_mask,circle_4points, labels
 

+ 1 - 1
models/line_detect/loi_heads.py

@@ -1410,7 +1410,7 @@ class RoIHeads(nn.Module):
 
                     if feature_logits is not None:
 
-                        circles_probs, circles_scores = circle_inference(feature_logits, point_proposals)
+                        circles_probs, circles_scores = circle_inference(feature_logits, circle_proposals)
                         for keypoint_prob, kps, r in zip(circles_probs, circles_scores, result):
                             r["circles"] = keypoint_prob
                             r["circles_scores"] = kps

+ 1 - 1
models/line_detect/train.yaml

@@ -7,7 +7,7 @@ io:
 #  datadir: /data/share/rlq/datasets/250718caisegangban
 
 #  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
-  datadir: \\192.168.50.222/share/zyh/data/rgb_4point/a_dataset
+  datadir: /data/share/zyh/data/rgb_4point/a_dataset
 #  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2
 #  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
   data_type: rgb

+ 78 - 0
models/line_detect/trainer.py

@@ -49,6 +49,48 @@ from torchvision.transforms import functional as F
 import torch
 
 
+
+
+def fit_circle(points):
+    """
+    Fit a circle to a set of points (at least 3).
+
+    Args:
+        points: torch.Tensor 或 numpy array, shape (N, 2)
+
+    Returns:
+        center (cx, cy), radius r
+    """
+    # 如果是 torch.Tensor,先转为 numpy
+    if isinstance(points, torch.Tensor):
+        if points.dim() == 3:
+            points = points[0]  # 去掉 batch 维度
+        points = points.detach().cpu().numpy()
+
+    if not (isinstance(points, np.ndarray) and points.ndim == 2 and points.shape[1] == 2):
+        raise ValueError(f"Expected points shape (N, 2), got {points.shape}")
+
+    x = points[:, 0].astype(float)
+    y = points[:, 1].astype(float)
+
+    # 确保 A 是二维数组
+    A = np.column_stack((x, y, np.ones_like(x)))  # 使用 column_stack 代替 stack 可能更清晰
+    b = -(x ** 2 + y ** 2)
+
+    try:
+        sol, residuals, rank, s = np.linalg.lstsq(A, b, rcond=None)
+    except np.linalg.LinAlgError as e:
+        print(f"Linear algebra error occurred: {e}")
+        raise ValueError("Could not fit circle to points.")
+
+    D, E, F = sol
+
+    cx = -D / 2.0
+    cy = -E / 2.0
+    r = np.sqrt(cx ** 2 + cy ** 2 - F)
+
+    return (cx, cy), r
+
 # 由低到高蓝黄红
 def draw_lines_with_scores(tensor_image, lines, scores, width=3, cmap='viridis'):
     """
@@ -232,6 +274,42 @@ class Trainer(BaseTrainer):
             # img_tensor = np.transpose(img_tensor)
             self.writer.add_image('z-out-arc', arcs, global_step=epoch)
 
+        if 'circles' in result:
+            # points=result['circles']
+            # points=points.squeeze(1)
+            ppp=result['circles']
+            bbb=result['boxes']
+            print(f'boxes shape:{bbb.shape}')
+            print(f'ppp:{ppp.shape}')
+            points = result['circles']
+            points = points.squeeze(1)
+            print(f'points shape:{points.shape}')
+
+            circle_image = img.cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC
+            circle_image = (circle_image * 255).clip(0, 255).astype(np.uint8)
+
+
+            if isinstance(points, torch.Tensor):
+                points = points.cpu().numpy()
+
+            for point_set in points:
+                center, radius = fit_circle(point_set)
+                cx, cy = map(int, center)
+
+                circle_image = cv2.circle(circle_image, (cx, cy), int(radius), (0, 0, 255), 2)
+
+                for point in point_set:
+                    px, py = map(int, point)
+                    circle_image = cv2.circle(circle_image, (px, py), 3, (0, 255, 255), -1)
+
+            img_rgb = cv2.cvtColor(circle_image, cv2.COLOR_BGR2RGB)
+            img_tensor = img_rgb.transpose((2, 0, 1))  # HWC -> CHW
+            img_tensor = img_tensor / 255.0  # 归一化到 [0, 1]
+
+
+            # keypoint_img = draw_keypoints((img * 255).to(torch.uint8), points, colors='red', width=3)
+            self.writer.add_image('z-out-circle', img_tensor, global_step=epoch)
+
             # cv2.imshow('arc', img_rgb)
             # cv2.waitKey(1000000)