zhaoyinghan 3 weeks ago
Parent
Commit
774a29460c

+ 53 - 47
models/line_detect/heads/arc/arc_heads.py

@@ -43,75 +43,81 @@ class ArcPredictor(nn.Module):
         # )
 
 
+
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
 class ArcEquationHead(nn.Module):
-    def __init__(self, num_outputs=7):
+    """
+    Input:
+        feature_logits : [N, 1, H, W]
+
+    Output:
+        arc_params : [N, 9]
+            0: center x (cx)
+            1: center y (cy)
+            2: long axis length (a)
+            3: short axis length (b)
+            4: ellipse angle (theta)
+            5: x of auxiliary point 1 (arc end)
+            6: y of auxiliary point 1 (arc end)
+            7: x of auxiliary point 2 (arc end)
+            8: y of auxiliary point 2 (arc end)
+    """
+
+    def __init__(self, num_outputs=9, hidden=512):
         super().__init__()
 
-        # --------------------------------------------------
-        # Convolution layers - no fixed H,W assumptions
-        # Automatically downsamples using stride=2
-        # --------------------------------------------------
-        self.conv = nn.Sequential(
-            nn.Conv2d(1, 32, 3, stride=2, padding=1),
-            nn.ReLU(inplace=True),
-
-            nn.Conv2d(32, 64, 3, stride=2, padding=1),
-            nn.ReLU(inplace=True),
-
-            nn.Conv2d(64, 128, 3, stride=2, padding=1),
-            nn.ReLU(inplace=True),
-
-            nn.Conv2d(128, 256, 3, stride=2, padding=1),
-            nn.ReLU(inplace=True),
-        )
-
-        # --------------------------------------------------
-        # Global pooling → no H,W dependency
-        # --------------------------------------------------
+        # Use GAP to remove spatial dependency
         self.gap = nn.AdaptiveAvgPool2d((1, 1))
 
-        # --------------------------------------------------
-        # MLP
-        # --------------------------------------------------
+        # Final MLP that maps pooled feature → arc parameters
         self.mlp = nn.Sequential(
-            nn.Linear(256, 256),
+            nn.Linear(1, hidden),
             nn.ReLU(inplace=True),
-            nn.Linear(256, num_outputs)
+            nn.Linear(hidden, num_outputs)
         )
 
-
     def forward(self, feature_logits):
         """
-        Args:
-            feature_logits: Tensor [N, 1, H, W]
+        feature_logits: [N, 1, H, W]
         """
+        N, _, H, W = feature_logits.shape
 
-        # CNN
-        x = self.conv(feature_logits)
-
-        # Global pool
-        x = self.gap(x).view(x.size(0), -1)
-
-        # Predict params
-        arc_params = self.mlp(x)   # -> [N, 7]
+        # --------------------------------------------
+        # Global average pooling
+        # Input  : [N, 1, H, W]
+        # Output : [N, 1]
+        # --------------------------------------------
+        x = self.gap(feature_logits)
+        x = x.view(N, -1)
 
-        N, _, H, W = feature_logits.shape
+        # Predict raw parameters
+        arc_params = self.mlp(x)   # [N, 9]
 
         # --------------------------------------------
-        # Apply constraints
+        # Parameter constraints
         # --------------------------------------------
-        arc_params[..., 0] = torch.sigmoid(arc_params[..., 0]) * W   # cx
-        arc_params[..., 1] = torch.sigmoid(arc_params[..., 1]) * H   # cy
 
-        arc_params[..., 2] = F.relu(arc_params[..., 2]) + 1e-6        # long axis
-        arc_params[..., 3] = F.relu(arc_params[..., 3]) + 1e-6        # short axis
+        # Ellipse center
+        arc_params[..., 0] = torch.sigmoid(arc_params[..., 0]) * W   # cx in image width range
+        arc_params[..., 1] = torch.sigmoid(arc_params[..., 1]) * H   # cy in image height range
+
+        # Axes lengths must be positive
+        arc_params[..., 2] = F.relu(arc_params[..., 2]) + 1e-6       # a > 0
+        arc_params[..., 3] = F.relu(arc_params[..., 3]) + 1e-6       # b > 0
+
+        # Angle constrained to [0, 2π)
+        arc_params[..., 4] = torch.sigmoid(arc_params[..., 4]) * (2 * 3.1415926535)
 
-        # angles 0~2π
-        arc_params[..., 4:7] = torch.sigmoid(arc_params[..., 4:7]) * (2 * 3.1415926535)
+        # ------------------------------------------------
+        # The last four values are the two auxiliary points (arc endpoints),
+        # mapped into the same spatial range as the image
+        # ------------------------------------------------
+        arc_params[..., 5] = torch.sigmoid(arc_params[..., 5]) * W   # endpoint 1 x
+        arc_params[..., 6] = torch.sigmoid(arc_params[..., 6]) * H   # endpoint 1 y
+        arc_params[..., 7] = torch.sigmoid(arc_params[..., 7]) * W   # endpoint 2 x
+        arc_params[..., 8] = torch.sigmoid(arc_params[..., 8]) * H   # endpoint 2 y
 
-        return arc_params
+        return arc_params
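
A minimal shape-and-range check of the new head, assuming a dummy [N, 1, H, W] logit map and that the repo root is on the Python path; the class name, file path and 9-value layout come from this commit, while the tensor sizes and the smoke test itself are purely illustrative.

import torch

from models.line_detect.heads.arc.arc_heads import ArcEquationHead  # path as in this commit

head = ArcEquationHead(num_outputs=9)          # 5 ellipse params + 2 endpoints (4 coords)
feature_logits = torch.randn(4, 1, 128, 160)   # dummy [N, 1, H, W] mask logits, H=128, W=160

arc_params = head(feature_logits)              # -> [4, 9]
assert arc_params.shape == (4, 9)

# cx/cy and the endpoint coordinates are squashed into image range by sigmoid,
# the axes are forced positive, and the angle lies in [0, 2*pi).
assert (arc_params[:, 0] >= 0).all() and (arc_params[:, 0] <= 160).all()   # cx in [0, W]
assert (arc_params[:, 2] > 0).all() and (arc_params[:, 3] > 0).all()       # a, b > 0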

+ 7 - 5
models/line_detect/line_dataset.py

@@ -71,6 +71,8 @@ class LineDataset(BaseDataset):
             img = PIL.Image.open(img_path).convert('RGB')
             w, h = img.size
         # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        print(img_path)
+        print_params(img)
         target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w),image=img)
 
         self.transforms = get_transforms(augmention=self.augmentation)
@@ -101,8 +103,8 @@ class LineDataset(BaseDataset):
 
         # print_params(arc_ends, arc_params)
 
-        if points is not None:
-            target["points"] = points
+        # if points is not None:
+            # target["points"] = points
         # if lines is not None:
         #     a = torch.full((lines.shape[0],), 2).unsqueeze(1)
         #     lines = torch.cat((lines, a), dim=1)
@@ -122,12 +124,12 @@ class LineDataset(BaseDataset):
             # arc_angles = compute_arc_angles(arc_ends, arc_params)
 
 
-            print_params(arc_ends,arc_params)
+            # print_params(arc_ends,arc_params)
             arc_masks = []
             for i in range(len(arc_params)):
-                mask = arc_to_mask_safe(arc_params[i], arc_ends[i], shape=(2000, 2000))
+                mask = arc_to_mask_safe(arc_params[i], arc_ends[i], shape=(2000, 2000),debug=False)
                 arc_masks.append(mask)
-            print_params(arc_masks)
+            # print_params(arc_masks)
             target['circle_masks'] = torch.stack(arc_masks, dim=0)
 
             # save_full_mask(torch.stack(arc_masks, dim=0), "arc_masks",

+ 41 - 31
models/line_detect/loi_heads.py

@@ -1395,8 +1395,12 @@ class RoIHeads(nn.Module):
 
                 print(f'features from backbone:{features['0'].shape}')
                 feature_logits = self.ins_forward1(features, image_shapes, ins_proposals)
+                # ins_masks, ins_scores, circle_points = ins_inference(feature_logits,
+                #                                                      ins_proposals, th=0)
 
-                # arc_equation = self.arc_equation_head(feature_logits)  # [total proposals, 7]
+
+
+                arc_equation = self.arc_equation_head(feature_logits)  # [total proposals, 9]
 
                 loss_ins = None
                 loss_ins_extra=None
@@ -1430,7 +1434,7 @@ class RoIHeads(nn.Module):
                         print(f'start to compute circle_loss')
 
                         loss_ins = compute_ins_loss(feature_logits, ins_proposals, gt_inses,ins_pos_matched_idxs)
-                        # total_loss, loss_arc_equation, loss_arc_ends = compute_arc_equation_loss(arc_equation,ins_proposals,gt_mask_params,ins_pos_matched_idxs,labels)
+                        total_loss, loss_arc_equation, loss_arc_ends = compute_arc_equation_loss(arc_equation,ins_proposals,gt_mask_ends,gt_mask_params,ins_pos_matched_idxs,labels)
                         loss_arc_ends = loss_arc_ends
                     if loss_arc_equation is None:
                         print(f'loss_arc_equation is None')
@@ -1472,7 +1476,7 @@ class RoIHeads(nn.Module):
 
                             loss_ins = compute_ins_loss(feature_logits, ins_proposals, gt_inses,
                                                            ins_pos_matched_idxs)
-                            # total_loss, loss_arc_equation, loss_arc_ends = compute_arc_equation_loss(arc_equation,ins_proposals,gt_mask_params,ins_pos_matched_idxs,labels)
+                            total_loss, loss_arc_equation, loss_arc_ends = compute_arc_equation_loss(arc_equation,ins_proposals,gt_mask_ends,gt_mask_params,ins_pos_matched_idxs,labels)
 
                             loss_arc_ends = loss_arc_ends
 
@@ -1524,10 +1528,10 @@ class RoIHeads(nn.Module):
                             ins_masks, ins_scores, circle_points = ins_inference(feature_logits,
                                                                                          ins_proposals, th=0)
 
-                            # arc7, arc_scores = arc_inference1(arc_equation, feature_logits, ins_proposals, 0.5)
-                            # for arc_, arc_score, r in zip(arc7, arc_scores, result):
-                            #     r["arcs"] = arc_
-                            #     r["arc_scores"] = arc_score
+                            arc7, arc_scores = arc_inference1(arc_equation, feature_logits, ins_proposals, 0.5)
+                            for arc_, arc_score, r in zip(arc7, arc_scores, result):
+                                r["arcs"] = arc_
+                                r["arc_scores"] = arc_score
                             # print(f'circles_probs:{circles_probs.shape}, circles_scores:{circles_scores.shape}')
                             proposals_per_image = [box.size(0) for box in ins_proposals]
                             print(f'ins_proposals_per_image:{proposals_per_image}')
@@ -1815,31 +1819,31 @@ def compute_arc_equation_loss(arc_equation, proposals, gt_mask_ends, gt_mask_par
         f'compute_arc_equation_loss line_logits.shape:{arc_equation.shape},len_proposals:{len_proposals},line_matched_idxs:{arc_pos_matched_idxs}')
     print(f'gt_mask_ends:{gt_mask_ends}, gt_mask_params:{gt_mask_params}')
 
-    gt_angles = []
-    # for gt_mask_end,gt_mask_param in zip(gt_mask_ends, gt_mask_params):
-    #     print(f'gt_mask_end:{gt_mask_end}, gt_mask_param:{gt_mask_param}')
-    #     gt_angles.append(compute_arc_angles(gt_mask_end,gt_mask_param))
-    for i in range(len(gt_mask_ends)):
-        print(f'gt_mask_end:{gt_mask_ends[i]}, gt_mask_param:{gt_mask_params[i]}')
-        gt_angles.append(compute_arc_angles(gt_mask_ends[i], gt_mask_params[i]))
+    # gt_angles = []
+    # # for gt_mask_end,gt_mask_param in zip(gt_mask_ends, gt_mask_params):
+    # #     print(f'gt_mask_end:{gt_mask_end}, gt_mask_param:{gt_mask_param}')
+    # #     gt_angles.append(compute_arc_angles(gt_mask_end,gt_mask_param))
+    # for i in range(len(gt_mask_ends)):
+    #     print(f'gt_mask_end:{gt_mask_ends[i]}, gt_mask_param:{gt_mask_params[i]}')
+    #     gt_angles.append(compute_arc_angles(gt_mask_ends[i], gt_mask_params[i]))
 
-    print(f'gt_angles:{gt_angles}')
+    # print(f'gt_angles:{gt_angles}')
     print(f'gt_mask_params:{gt_mask_params}')
     print(f'gt_labels_all:{gt_labels_all}')
     print(f'arc_pos_matched_idxs:{arc_pos_matched_idxs}')
 
     gt_sel_params = []
     gt_sel_angles = []
-    for proposals_per_image, gt_angle, gt_params, gt_label, midx in zip(proposals, gt_angles, gt_mask_params,
+    for proposals_per_image, gt_ends, gt_params, gt_label, midx in zip(proposals, gt_mask_ends, gt_mask_params,
                                                                         gt_labels_all, arc_pos_matched_idxs):
         print(f'line_proposals_per_image:{proposals_per_image.shape}')
         # gt_angle = torch.tensor(gt_angle)
-        gt_angle = torch.stack(gt_angle, dim=0)
+        gt_ends = torch.tensor(gt_ends)
         gt_params = torch.tensor(gt_params)
-        if gt_angle.shape[0] > 0:
+        if gt_ends.shape[0] > 0:
             # positions = (gt_label == 3).nonzero()[0].item()
 
-            po = gt_angle[midx.cpu()]
+            po = gt_ends[midx.cpu()]
             pa = gt_params[midx.cpu()]
             print(f'po:{po},pa:{pa}')
 
@@ -1850,34 +1854,40 @@ def compute_arc_equation_loss(arc_equation, proposals, gt_mask_ends, gt_mask_par
 
     gt_sel_angles = torch.cat(gt_sel_angles, dim=0)
     gt_sel_params = torch.cat(gt_sel_params, dim=0)
-    pred_angles = arc_equation[:, 5:7]
+    pred_ends = arc_equation[:, 5:9]
     pred_params = arc_equation[:, :5]
 
-    print_params(pred_angles, pred_params, gt_sel_angles, gt_sel_params)
+    # print_params(pred_angles, pred_params, gt_sel_angles, gt_sel_params)
 
-    pred_sin = torch.sin(pred_angles)
-    pred_cos = torch.cos(pred_angles)
-    gt_sin = torch.sin(gt_sel_angles)
-    gt_cos = torch.cos(gt_sel_angles)
-    angle_loss = F.mse_loss(pred_sin, gt_sin) + F.mse_loss(pred_cos, gt_cos)
+    # pred_sin = torch.sin(pred_angles)
+    # pred_cos = torch.cos(pred_angles)
+    # gt_sin = torch.sin(gt_sel_angles)
+    # gt_cos = torch.cos(gt_sel_angles)
+    # angle_loss = F.mse_loss(pred_sin, gt_sin) + F.mse_loss(pred_cos, gt_cos)
 
 
     param_loss = F.mse_loss(pred_params, gt_sel_params) / 10000
+    print("start")
+    print_params(pred_ends, gt_sel_angles)
+    pred_ends = pred_ends.view(-1, 2, 2)
+    print("end")
+    print_params(pred_ends, gt_sel_angles)
+    ends_loss = F.mse_loss(pred_ends, gt_sel_angles) / 10000
 
-    print(f'angle_loss:{angle_loss.item()}, param_loss:{param_loss.item()}')
+    # print(f'angle_loss:{angle_loss.item()}, param_loss:{param_loss.item()}')
 
 
     count = sum(len(sublist) for sublist in proposals)
-    total_loss = ((param_loss + angle_loss) / count) if count > 0 else torch.tensor(0.0, device=device,
+    total_loss = ((param_loss + ends_loss) / count) if count > 0 else torch.tensor(0.0, device=device,
                                                                                     dtype=torch.float)
 
     total_loss = total_loss.to(device)
-    angle_loss = angle_loss.to(device)
+    ends_loss = ends_loss.to(device)
     param_loss = param_loss.to(device)
 
-    print(f'total_loss, param_loss, angle_loss: {total_loss.item()}, {param_loss.item()}, {angle_loss.item()}')
+    # print(f'total_loss, param_loss, angle_loss: {total_loss.item()}, {param_loss.item()}, {angle_loss.item()}')
 
-    return total_loss, param_loss, angle_loss
+    return total_loss, param_loss, ends_loss
 
 
     # angle_loss = F.mse_loss(pred_angles, gt_sel_angles)
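
A minimal standalone sketch of the reworked loss terms, assuming arc_equation is the [N, 9] head output and that the ground-truth params/ends have already been gathered per matched proposal as in the diff above; the /10000 scaling and the per-proposal normalisation mirror the code, everything else is illustrative.

import torch
import torch.nn.functional as F

def arc_equation_loss_sketch(arc_equation, gt_params, gt_ends, num_proposals):
    # arc_equation : [N, 9] -> (cx, cy, a, b, theta, x1, y1, x2, y2)
    # gt_params    : [N, 5] ground-truth ellipse parameters
    # gt_ends      : [N, 2, 2] ground-truth arc endpoints
    pred_params = arc_equation[:, :5]
    pred_ends = arc_equation[:, 5:9].view(-1, 2, 2)   # two (x, y) endpoints per arc

    param_loss = F.mse_loss(pred_params, gt_params) / 10000
    ends_loss = F.mse_loss(pred_ends, gt_ends) / 10000
    total_loss = (param_loss + ends_loss) / max(num_proposals, 1)
    return total_loss, param_loss, ends_loss

Note that a plain endpoint MSE like this is sensitive to which ground-truth endpoint is listed first; swapping the two GT points changes the loss even for the same predicted arc.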

+ 3 - 2
models/line_detect/train.yaml

@@ -7,7 +7,8 @@ io:
 #  datadir: /data/share/rlq/datasets/250718caisegangban
 
 #  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
-  datadir: /data/share/zyh/master_dataset/pokou/251115/a_dataset_pokou_mask
+#  datadir: /data/share/zyh/master_dataset/dataset_net/pokou_251115/a_dataset_pokou_mask
+  datadir: /data/share/zyh/master_dataset/dataset_net/pokou_251115_251121/a_dataset
 #  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2
 #  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
   data_type: rgb
@@ -20,7 +21,7 @@ io:
 train_params:
   resume_from:
   num_workers: 8
-  batch_size: 1
+  batch_size: 2
   max_epoch: 8000000
 #  augmentation: True
   augmentation: False

+ 78 - 28
models/line_detect/trainer.py

@@ -23,6 +23,8 @@ from tools import utils
 
 import matplotlib as mpl
 
+from utils.data_process.show_prams import print_params
+
 cmap = plt.get_cmap("jet")
 norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
 sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
@@ -153,41 +155,42 @@ def fit_circle(points):
     return (cx, cy), r
 from PIL import ImageDraw, Image
 import io
-# Draw ellipse
+
 def draw_el(all, background_img):
-    # Parse ellipse parameters
+    """
+    all = [x_center, y_center, a, b, theta, x1, y1, x2, y2]
+        theta: ellipse rotation (degrees)
+        (x1, y1): start point
+        (x2, y2): end point
+    """
+
     if isinstance(all, torch.Tensor):
         all = all.cpu().numpy()
-    x, y, a, b, q, q1, q2 = all
-    theta = np.radians(q)
-    phi1 = np.radians(q1)  # parametric angle of the first point
-    phi2 = np.radians(q2)  # parametric angle of the second point
-
-    # Generate points on the ellipse
-    phi = np.linspace(0, 2 * np.pi, 500)
-    x_ellipse = x + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
-    y_ellipse = y + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
-
-    # Compute the coordinates of the two specified points
-    def param_to_point(phi, xc, yc, a, b, theta):
-        x = xc + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
-        y = yc + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
-        return x, y
-
-    P1 = param_to_point(phi1, x, y, a, b, theta)
-    P2 = param_to_point(phi2, x, y, a, b, theta)
-
-    # Create the canvas and show the background image (background_img, shape [H, W, C])
+
+    # Unpack parameters
+    cx, cy, a, b, theta_deg, x1, y1, x2, y2 = all
+    theta = np.radians(theta_deg)
+
+    # ====== Draw ellipse ======
+    phi = np.linspace(0, np.pi * 2, 500)
+    x_ellipse = cx + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
+    y_ellipse = cy + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
+
+    # ====== Draw image ======
     plt.figure(figsize=(10, 10))
-    plt.imshow(background_img)  # show the background image directly
+    plt.imshow(background_img)
 
-    # Draw the ellipse and related elements
+    # Ellipse
     plt.plot(x_ellipse, y_ellipse, 'b-', linewidth=2)
-    plt.plot(x, y, 'ko', markersize=8)
-    plt.plot(P1[0], P1[1], 'ro', markersize=10)
-    plt.plot(P2[0], P2[1], 'go', markersize=10)
 
-    # Convert to the tensor format TensorBoard expects: [C, H, W]
+    # Center
+    plt.plot(cx, cy, 'ko', markersize=8)
+
+    # Start & End points (now real coordinates)
+    plt.plot(x1, y1, 'ro', markersize=10)
+    plt.plot(x2, y2, 'go', markersize=10)
+
+    # ====== Convert to tensor ======
     buf = io.BytesIO()
     plt.savefig(buf, format='png', bbox_inches='tight')
     buf.seek(0)
@@ -196,6 +199,53 @@ def draw_el(all, background_img):
     plt.close()
 
     return img_tensor
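
A short usage sketch of the rewritten draw_el, assuming it is imported from models/line_detect/trainer.py; the background array and all nine values are made up for illustration. Note that draw_el converts theta with np.radians, while ArcEquationHead above emits the angle in radians, so raw head output would need a radians-to-degrees conversion before plotting.

import numpy as np
import torch

from models.line_detect.trainer import draw_el  # path as in this commit

# Hypothetical input: a blank 800x1000 RGB background and an ellipse centred at
# (500, 400) with axes 120/60, rotated 30 degrees; the last four values are the
# two arc endpoints in pixel coordinates (no longer parametric angles).
background = np.zeros((800, 1000, 3), dtype=np.uint8)
arc = torch.tensor([500.0, 400.0, 120.0, 60.0, 30.0,
                    610.0, 430.0, 400.0, 350.0])

img_tensor = draw_el(arc, background)   # -> [C, H, W] tensor ready for add_image
print(img_tensor.shape)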
+
+# from PIL import ImageDraw, Image
+# import io
+# # Draw ellipse
+# def draw_el(all, background_img):
+#     # Parse ellipse parameters
+#     if isinstance(all, torch.Tensor):
+#         all = all.cpu().numpy()
+#     print_params(all)
+#     x, y, a, b, q, q1, q2 = all
+#     theta = np.radians(q)
+#     phi1 = np.radians(q1)  # parametric angle of the first point
+#     phi2 = np.radians(q2)  # parametric angle of the second point
+#
+#     # Generate points on the ellipse
+#     phi = np.linspace(0, 2 * np.pi, 500)
+#     x_ellipse = x + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
+#     y_ellipse = y + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
+#
+#     # Compute the coordinates of the two specified points
+#     def param_to_point(phi, xc, yc, a, b, theta):
+#         x = xc + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
+#         y = yc + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
+#         return x, y
+#
+#     P1 = param_to_point(phi1, x, y, a, b, theta)
+#     P2 = param_to_point(phi2, x, y, a, b, theta)
+#
+#     # Create the canvas and show the background image (background_img, shape [H, W, C])
+#     plt.figure(figsize=(10, 10))
+#     plt.imshow(background_img)  # show the background image directly
+#
+#     # Draw the ellipse and related elements
+#     plt.plot(x_ellipse, y_ellipse, 'b-', linewidth=2)
+#     plt.plot(x, y, 'ko', markersize=8)
+#     plt.plot(P1[0], P1[1], 'ro', markersize=10)
+#     plt.plot(P2[0], P2[1], 'go', markersize=10)
+
+    # Convert to the tensor format TensorBoard expects: [C, H, W]
+    # buf = io.BytesIO()
+    # plt.savefig(buf, format='png', bbox_inches='tight')
+    # buf.seek(0)
+    # result_img = Image.open(buf).convert('RGB')
+    # img_tensor = torch.from_numpy(np.array(result_img)).permute(2, 0, 1)
+    # plt.close()
+    #
+    # return img_tensor
 # From low to high: blue, yellow, red
 def draw_lines_with_scores(tensor_image, lines, scores, width=3, cmap='viridis'):
     """

+ 91 - 63
utils/data_process/csv_circle/csv_read.py

@@ -6,164 +6,192 @@ import math
 from typing import List, Union, Dict
 
 # === Folder configuration ===
-csv_folder = r"\\192.168.50.222\share\zyh\master_dataset\pokou\remark_251104\params"  # CSV folder
-json_folder = r"\\192.168.50.222\share\zyh\master_dataset\pokou\total"  # JSON and image folder
-output_folder = r"\\192.168.50.222\share\zyh\master_dataset\pokou\251115\csvjson"  # output folder
+csv_folder = r"/data/share/zyh/master_dataset/pokou/merge/251121_251115/csv"     # CSV folder
+json_folder_json = r"/data/share/zyh/master_dataset/pokou/merge/251121_251115/json"         # JSON folder
+json_folder_img  = r"/data/share/zyh/master_dataset/pokou/merge/251121_251115/image"       # image folder
+output_folder = r"/data/share/zyh/master_dataset/dataset_net/pokou_251115_251121/to_dataset"        # output folder
 
 os.makedirs(output_folder, exist_ok=True)
 
+
+# ==============================================================
+# Compute arc endpoints
+# ==============================================================
 def compute_arc_ends(points: List[List[float]]) -> List[List[float]]:
     if len(points) != 3:
-        return [[0,0],[0,0]]
+        return [[0, 0], [0, 0]]
 
     p1, p2, p3 = points
     x1, y1 = p1
     x2, y2 = p2
     x3, y3 = p3
 
-    # --- Circle center ---
-    A = 2*(x2-x1)
-    B = 2*(y2-y1)
-    C = x2**2+y2**2-x1**2-y1**2
-    D = 2*(x3-x2)
-    E = 2*(y3-y2)
-    F = x3**2+y3**2-x2**2-y2**2
+    A = 2 * (x2 - x1)
+    B = 2 * (y2 - y1)
+    C = x2**2 + y2**2 - x1**2 - y1**2
+    D = 2 * (x3 - x2)
+    E = 2 * (y3 - y2)
+    F = x3**2 + y3**2 - x2**2 - y2**2
 
-    denom = A*E - B*D
+    denom = A * E - B * D
     if denom == 0:
         return [p1, p3]
 
-    cx = (C*E - F*B)/denom
-    cy = (A*F - D*C)/denom
+    cx = (C * E - F * B) / denom
+    cy = (A * F - D * C) / denom
 
-    angles = [math.atan2(y-cy, x-cx) for x,y in points]
+    angles = [math.atan2(y - cy, x - cx) for x, y in points]
 
-    def angle_diff(a1,a2):
-        diff = (a2-a1)%(2*math.pi)
-        if diff>math.pi: diff=2*math.pi-diff
+    def angle_diff(a1, a2):
+        diff = (a2 - a1) % (2 * math.pi)
+        if diff > math.pi:
+            diff = 2 * math.pi - diff
         return diff
 
-    pairs = [(0,1),(0,2),(1,2)]
-    max_diff=-1
-    end_pair=(0,1)
-    for i,j in pairs:
-        diff=angle_diff(angles[i],angles[j])
-        if diff>max_diff:
-            max_diff=diff
-            end_pair=(i,j)
+    pairs = [(0, 1), (0, 2), (1, 2)]
+    max_diff = -1
+    end_pair = (0, 1)
+
+    for i, j in pairs:
+        diff = angle_diff(angles[i], angles[j])
+        if diff > max_diff:
+            max_diff = diff
+            end_pair = (i, j)
 
     return [points[end_pair[0]], points[end_pair[1]]]
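
A quick sanity check of compute_arc_ends, assuming the function above is copied into a scratch script: for three points on a circle it fits the centre and returns the pair of points spanning the largest angular difference. The points below are made up for illustration.

# Three points on the unit circle centred at (0, 0): angles 0, 90 and 180 degrees.
pts = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
print(compute_arc_ends(pts))
# [[1.0, 0.0], [-1.0, 0.0]] -- the 0/180-degree pair has the largest angular spread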
 
 
-# === Utility: match a point to the nearest ellipse ===
+# ==============================================================
+# Match a point to the nearest ellipse
+# ==============================================================
 def match_point_to_ellipse(point: List[float], ellipses: List[Dict]) -> int:
-    """
-    Match a point to the nearest ellipse
-    point: [x, y]
-    ellipses: list of dicts, each containing cx, cy
-    return: ellipse_index
-    """
     x, y = point
-    min_dist = float('inf')
+    min_dist = float("inf")
     match_idx = -1
+
     for i, e in enumerate(ellipses):
-        cx, cy = e['cx'], e['cy']
+        cx, cy = e["cx"], e["cy"]
         dist = math.hypot(x - cx, y - cy)
         if dist < min_dist:
             min_dist = dist
             match_idx = i
+
     return match_idx
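
A small illustrative call, with hypothetical ellipse records carrying only the cx/cy keys the function reads:

ellipses = [{"cx": 100.0, "cy": 100.0}, {"cx": 800.0, "cy": 600.0}]
print(match_point_to_ellipse([790.0, 610.0], ellipses))
# 1 -- the point is closest to the second ellipse centre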
 
-# === Read CSVs and build a filename-to-ellipse-parameters map ===
+
+# ==============================================================
+# Read the ellipse parameter map from the CSVs
+# ==============================================================
 csv_ellipse_map = {}  # filename -> list of ellipse params
+
 for csv_file in os.listdir(csv_folder):
     if not csv_file.endswith(".csv"):
         continue
+
     csv_path = os.path.join(csv_folder, csv_file)
+
     with open(csv_path, "r", encoding="utf-8-sig") as f:
         reader = csv.DictReader(f)
         for row in reader:
             filename = row["filename"].strip()
             shape_str = row["region_shape_attributes"]
+
             try:
                 shape_data = json.loads(shape_str)
             except json.JSONDecodeError:
                 shape_data = json.loads(shape_str.replace('""', '"'))
+
             if filename not in csv_ellipse_map:
                 csv_ellipse_map[filename] = []
+
             csv_ellipse_map[filename].append(shape_data)
 
-# === Iterate over JSON files ===
-for json_file in os.listdir(json_folder):
+
+# ==============================================================
+# Iterate over JSON files
+# ==============================================================
+for json_file in os.listdir(json_folder_json):
     if not json_file.endswith(".json"):
         continue
-    json_path = os.path.join(json_folder, json_file)
-    filename = json_file.replace(".json", ".jpg")
-    img_path = os.path.join(json_folder, filename)
 
+    json_path = os.path.join(json_folder_json, json_file)
+    filename = json_file.replace(".json", ".jpg")      # image file name
+    img_path = os.path.join(json_folder_img, filename) # images live in their own folder
+
+    # Check that the image exists
     if not os.path.exists(img_path):
-        print(f"?? Image not found for {filename}")
+        print(f"[WARN] Image not found for: {filename}")
         continue
+
+    # The CSV must contain a matching record
     if filename not in csv_ellipse_map:
-        print(f"?? CSV ellipse not found for {filename}")
+        print(f"[WARN] No CSV ellipse for: {filename}")
         continue
 
+    # Read the JSON
     with open(json_path, "r", encoding="utf-8") as jf:
         data = json.load(jf)
 
     if "shapes" not in data:
         data["shapes"] = []
 
-    # Collect all single-point arc annotations
-    arc_points = [s["points"][0] for s in data["shapes"]
-                  if s.get("label")=="arc" and "points" in s and len(s["points"])==1]
+    # Get the single-point arc annotations from the JSON
+    arc_points = [
+        s["points"][0]
+        for s in data["shapes"]
+        if s.get("label") == "arc" and "points" in s and len(s["points"]) == 1
+    ]
 
-    # Group points by matching against the CSV ellipses
+    # Get the ellipse info from the CSV
     ellipses = csv_ellipse_map[filename]
     ellipse_point_map = {i: [] for i in range(len(ellipses))}
+
+    # Match each arc point to its nearest ellipse
     for pt in arc_points:
         idx = match_point_to_ellipse(pt, ellipses)
         ellipse_point_map[idx].append(pt)
 
-    # Report images that match two or more ellipses
-    active_ellipses = [i for i, pts in ellipse_point_map.items() if len(pts) >= 1]
-    if len(active_ellipses) >= 2:
-        print(f"Image {filename} matches {len(active_ellipses)} ellipses")
-
-    # Build the new arc_shape entries
+    # Generate the new arc shapes
     new_arc_shapes = []
     for idx, pts in ellipse_point_map.items():
         if len(pts) != 3:
-            print(f"?? {filename} ellipse {idx} points not equal 3, got {len(pts)}")
-            ends = [[0,0],[0,0]]
+            print(f"[WARN] {filename} ellipse {idx} has {len(pts)} points (expected 3)")
+            ends = [[0, 0], [0, 0]]
         else:
             ends = compute_arc_ends(pts)
+
         e = ellipses[idx]
         arc_shape = {
             "label": "arc",
             "points": pts,
-            "params": [e.get("cx",0), e.get("cy",0), e.get("rx",0), e.get("ry",0), e.get("theta",0)],
+            "params": [
+                e.get("cx", 0),
+                e.get("cy", 0),
+                e.get("rx", 0),
+                e.get("ry", 0),
+                e.get("theta", 0),
+            ],
             "ends": ends,
             "group_id": None,
             "description": "",
             "difficult": False,
             "shape_type": "arc",
             "flags": {},
-            "attributes": {}
+            "attributes": {},
         }
         new_arc_shapes.append(arc_shape)
 
-    # Keep non-arc shapes
-    remaining_shapes = [s for s in data["shapes"] if s.get("label") != "arc"]
-    data["shapes"] = remaining_shapes + new_arc_shapes
+    # Drop the old arcs and add the new ones
+    remaining = [s for s in data["shapes"] if s.get("label") != "arc"]
+    data["shapes"] = remaining + new_arc_shapes
 
-    # Save the JSON
+    # Write the output JSON
     output_json = os.path.join(output_folder, json_file)
     with open(output_json, "w", encoding="utf-8") as jf:
         json.dump(data, jf, ensure_ascii=False, indent=2)
 
     # Copy the image
     shutil.copy2(img_path, os.path.join(output_folder, filename))
-    print(f"Saved merged JSON and image for: {filename}")
+    print(f"[OK] Saved merged data for: {filename}")
 
-print("\nAll done! Final JSONs and images saved in:", output_folder)
+print("\nAll done! Output in:", output_folder)

+ 1 - 1
utils/data_process/csv_circle/csv_show.py

@@ -5,7 +5,7 @@ import numpy as np
 import math
 
 # === Configuration ===
-input_folder = r"/data/share/zyh/master_dataset/pokou/251115/output"  # input JSONs and images
+input_folder = r"/data/share/zyh/master_dataset/dataset_net/pokou_251115_251121/resize"  # input JSONs and images
 output_folder = os.path.join(os.path.dirname(input_folder), "show")
 os.makedirs(output_folder, exist_ok=True)
 

+ 2 - 2
utils/data_process/csv_circle/resice.py

@@ -3,8 +3,8 @@ from PIL import Image
 import shutil
 
 # Input and output folders
-input_folder = "/data/share/zyh/master_dataset/pokou/251115/csvjson"
-output_folder = "/data/share/zyh/master_dataset/pokou/251115/output"
+input_folder = "/data/share/zyh/master_dataset/dataset_net/pokou_251115_251121/to_dataset"
+output_folder = "/data/share/zyh/master_dataset/dataset_net/pokou_251115_251121/resize"
 os.makedirs(output_folder, exist_ok=True)
 
 # Iterate over the input folder

+ 2 - 2
utils/data_process/d_data_spliter.py

@@ -80,10 +80,10 @@ def organize_data(
 
 if __name__ == "__main__":
     # Input/output directories (modifiable)
-    source_dir = r"/data/share/zyh/master_dataset/pokou/251115/output"
+    source_dir = r"/data/share/zyh/master_dataset/dataset_net/pokou_251115_251121/resize"
 
     parent_dir = os.path.dirname(source_dir)
-    output_dir = os.path.join(parent_dir, "a_dataset_pokou_mask")
+    output_dir = os.path.join(parent_dir, "a_dataset")
 
     # List of file extensions, making it easy to add other formats later
     image_exts = ['.tiff','.jpg','.png']