Browse Source

fix arc dataset

zhaoyinghan 4 weeks ago
parent
commit
3b98c3f1dc
2 changed files with 171 additions and 69 deletions
  1. 115 21
      models/line_detect/line_dataset.py
  2. 56 48
      utils/data_process/mask/show_mask.py

+ 115 - 21
models/line_detect/line_dataset.py

@@ -71,7 +71,7 @@ class LineDataset(BaseDataset):
             img = PIL.Image.open(img_path).convert('RGB')
             w, h = img.size
         # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
-        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w),image=img)
 
         self.transforms = get_transforms(augmention=self.augmentation)
 
@@ -82,7 +82,7 @@ class LineDataset(BaseDataset):
     def __len__(self):
         return len(self.imgs)
 
-    def read_target(self, item, lbl_path, shape, extra=None):
+    def read_target(self, item, lbl_path, shape, extra=None,image=None):
         # print(f'shape:{shape}')
         # print(f'lbl_path:{lbl_path}')
         with open(lbl_path, 'r') as file:
@@ -118,30 +118,21 @@ class LineDataset(BaseDataset):
             target['mask_params'] = arc_params
 
 
-            arc_angles = compute_arc_angles(arc_ends, arc_params)
-            # print_params(arc_angles)
-            arc_masks = []
 
+            # arc_angles = compute_arc_angles(arc_ends, arc_params)
 
 
+            print_params(arc_ends,arc_params)
+            arc_masks = []
             for i in range(len(arc_params)):
-                arc_param_i = arc_params[i].view(-1)  # shape (5,)
-                arc_angle_i = arc_angles[i].view(-1)  # shape (2,)
-                arc7 = torch.cat([arc_param_i, arc_angle_i], dim=0)  # shape (7,)
-
-
-                # print_params(arc7)
-                mask = arc_to_mask(arc7, shape, line_width=1)
-
+                mask = arc_to_mask_safe(arc_params[i], arc_ends[i], shape=(2000, 2000))
                 arc_masks.append(mask)
-                # arc7=arc_params[i] + arc_angles[i].tolist()
-                # arc_masks.append(arc_to_mask(arc7, shape, line_width=1))
-
-            # print(f'circle_masks:{torch.stack(arc_masks, dim=0).shape}')
-
+            print_params(arc_masks)
             target['circle_masks'] = torch.stack(arc_masks, dim=0)
-            save_full_mask(target['circle_masks'], "arc_masks",
-                           "/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_dataset")
+
+            # save_full_mask(torch.stack(arc_masks, dim=0), "arc_masks",
+            #                "/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_dataset",
+            #                force_save=False,image=image,show_on_image=True)
 
 
 
@@ -249,6 +240,109 @@ class LineDataset(BaseDataset):
         pass
 
 
+import torch
+import numpy as np
+import cv2
+
+def arc_to_mask_safe(arc_param, arc_end, shape, line_width=5, debug=True, idx=-1):
+    """
+    Generate a mask for a small (<180 degree) arc based on arc parameters and endpoints.
+
+    Args:
+        arc_param: torch.Tensor of shape (5,)  - [cx, cy, a, b, theta]
+        arc_end: torch.Tensor of shape (2,2)  - [[x1,y1],[x2,y2]]
+        shape: tuple (H,W) - mask size
+        line_width: thickness of the arc
+        debug: bool - if True, print debug info
+        idx: int or str - index for debugging identification
+
+    Returns:
+        mask: torch.Tensor of shape (H,W)
+    """
+
+    # ------------------ Check for all-zero input ------------------
+    if torch.all(arc_param == 0) or torch.all(arc_end == 0):
+        if debug:
+            print(f"[{idx}] Warning: arc_param or arc_end all zeros. Returning zero mask.")
+            print(f"[{idx}] arc_param: {arc_param.tolist()}")
+            print(f"[{idx}] arc_end: {arc_end.tolist()}")
+        return torch.zeros(shape, dtype=torch.float32)
+
+    cx, cy, a, b, theta = arc_param.tolist()
+
+    if a <= 0 or b <= 0:
+        if debug:
+            print(f"[{idx}] Warning: invalid ellipse axes a={a}, b={b}. Returning zero mask.")
+            print(f"[{idx}] arc_param: {arc_param.tolist()}")
+            print(f"[{idx}] arc_end: {arc_end.tolist()}")
+        return torch.zeros(shape, dtype=torch.float32)
+
+    x1, y1 = arc_end[0].tolist()
+    x2, y2 = arc_end[1].tolist()
+
+    cos_t = np.cos(theta)
+    sin_t = np.sin(theta)
+
+    def point_to_angle(x, y):
+        dx = x - cx
+        dy = y - cy
+        x_ = cos_t * dx + sin_t * dy
+        y_ = -sin_t * dx + cos_t * dy
+        return np.arctan2(y_ / b, x_ / a)
+
+    try:
+        angle1 = point_to_angle(x1, y1)
+        angle2 = point_to_angle(x2, y2)
+    except Exception as e:
+        if debug:
+            print(f"[{idx}] Exception in point_to_angle: {e}")
+            print(f"[{idx}] arc_param: {arc_param.tolist()}, arc_end: {arc_end.tolist()}")
+        return torch.zeros(shape, dtype=torch.float32)
+
+    if np.isnan(angle1) or np.isnan(angle2):
+        if debug:
+            print(f"[{idx}] Warning: angle1 or angle2 is NaN. Returning zero mask.")
+            print(f"[{idx}] arc_param: {arc_param.tolist()}, arc_end: {arc_end.tolist()}")
+        return torch.zeros(shape, dtype=torch.float32)
+
+    # Ensure small arc (<180 degrees)
+    if angle2 < angle1:
+        angle2 += 2 * np.pi
+    if angle2 - angle1 > np.pi:
+        angle1, angle2 = angle2, angle1 + 2 * np.pi
+
+    angles = np.linspace(angle1, angle2, 100)
+    xs = cx + a * np.cos(angles) * cos_t - b * np.sin(angles) * sin_t
+    ys = cy + a * np.cos(angles) * sin_t + b * np.sin(angles) * cos_t
+
+    xs = np.nan_to_num(xs, nan=0.0).astype(np.int32)
+    ys = np.nan_to_num(ys, nan=0.0).astype(np.int32)
+
+    # ------------------ Debug prints ------------------
+    if debug:
+        print(f"[{idx}] arc_param: {arc_param.tolist()}")
+        print(f"[{idx}] arc_end: {arc_end.tolist()}")
+        print(f"[{idx}] xs[:5], ys[:5]: {xs[:5]}, {ys[:5]}")
+
+    mask = np.zeros(shape, dtype=np.uint8)
+    pts = np.stack([xs, ys], axis=1)
+
+    # Draw the arc with given line_width
+    for i in range(len(pts) - 1):
+        cv2.line(mask, tuple(pts[i]), tuple(pts[i + 1]), color=1, thickness=line_width)
+
+    # ------------------ Extra check for non-zero mask ------------------
+    if debug:
+        mask_sum = mask.sum()
+        if mask_sum == 0:
+            print(f"[{idx}] Warning: mask generated is all zeros!")
+        else:
+            print(f"[{idx}] mask sum: {mask_sum}")
+
+    return torch.tensor(mask, dtype=torch.float32)
+
+
+
 def draw_el(all):
     # 解析椭圆参数
     if isinstance(all, torch.Tensor):
@@ -585,6 +679,6 @@ def get_boxes_lines(objs, shape):
 
 
 if __name__ == '__main__':
-    path = r'\\192.168.50.222/share/zyh/master_dataset/pokou/251115/a_dataset_pokou_mask'
+    path = r'/data/share/zyh/master_dataset/pokou/251115/a_dataset_pokou_mask'
     dataset = LineDataset(dataset_path=path, dataset_type='train', augmentation=False, data_type='jpg')
     dataset.show(19, show_type='arc_yuan_point_ellipse')

+ 56 - 48
utils/data_process/mask/show_mask.py

@@ -14,41 +14,62 @@ def save_full_mask(
     max_files=100,
     save_all_zero=False,
     zero_eps=1e-6,
-    # ÐÂÔö²ÎÊý
-    image=None,            # PIL.Image »ò numpy(H,W,3)
-    show_on_image=False,   # ÊÇ·ñ±£´æµþ¼ÓµÄͼ
-    alpha=0.4              # mask͸Ã÷¶È
+    # New parameters
+    image=None,            # PIL.Image or numpy array (H,W,3)
+    show_on_image=False,   # Whether to save overlay on original image
+    alpha=0.4,             # Overlay transparency
+    force_save=False       # Force save even if mask is all zeros
 ):
     """
-    ±£´æÕûÕÅÌØÕ÷ͼ£¬¿ÉÑ¡ÔñÉú³Éµþ¼ÓԭͼºóµÄ¿ÉÊÓ»¯Îļþ
+    Save a full mask tensor safely, optionally with overlay on an image.
+    Supports batch masks (shape: [N,H,W]) and saves each slice separately.
+    Force save if force_save=True even if the mask is all zeros.
     """
 
-    # È¥µô batch/channel ¶àÓàά¶È
-    if mask.dim() == 4:
-        mask = mask[0, 0]
-    elif mask.dim() == 3:
-        mask = mask[0]
+    # ---------------- Convert batch masks ----------------
+    if mask.dim() == 3:  # (N, H, W)
+        for idx in range(mask.shape[0]):
+            save_full_mask(
+                mask[idx],
+                name=f"{name}_{idx}",
+                out_dir=out_dir,
+                save_png=save_png,
+                save_txt=save_txt,
+                save_npy=save_npy,
+                max_files=max_files,
+                save_all_zero=save_all_zero,
+                zero_eps=zero_eps,
+                image=image,
+                show_on_image=show_on_image,
+                alpha=alpha,
+                force_save=force_save
+            )
+        return
+
+    elif mask.dim() > 3:
+        # Flatten batch/channel dimensions safely, take first image
+        mask = mask.view(-1, mask.shape[-2], mask.shape[-1])[0]
 
-    # ---------------- È« 0 ÅжϺ¯Êý ----------------
+    mask_cpu = mask.detach().cpu()
+
+    # ---------------- Check if mask is all zeros ----------------
     def is_all_zero(t: torch.Tensor):
         if t.dtype.is_floating_point:
             return torch.all(torch.abs(t) < zero_eps)
         else:
             return torch.all(t == 0)
 
-    # È« 0 Çé¿ö´¦Àí
-    if is_all_zero(mask):
-        print(f"? {name} È« 0")
-        if not save_all_zero:
-            print(f"¡ú δ±£´æ£¨save_all_zero=False£©")
-            return
-        else:
-            print(f"¡ú °´ÉèÖüÌÐø±£´æ£¨save_all_zero=True£©")
+    all_zero = is_all_zero(mask_cpu)
+    if all_zero and not save_all_zero and not force_save:
+        print(f"{name} all zeros ¡ú Not saved")
+        return
+    elif all_zero and force_save:
+        print(f"{name} all zeros ¡ú force saving")
 
+    # ---------------- Create output directory ----------------
     os.makedirs(out_dir, exist_ok=True)
-    mask_cpu = mask.detach().cpu()
 
-    # ------------------- ·ÖÅäÐòºÅ -------------------
+    # ---------------- Determine file index ----------------
     pattern = os.path.join(out_dir, f"{name}_*.png")
     existing_files = sorted(glob(pattern))
     if existing_files:
@@ -67,11 +88,7 @@ def save_full_mask(
 
     file_idx_str = f"{next_idx:03d}"
 
-    # ==========================================================
-    # ¢Ù ±£´æ mask ±¾Éí
-    # ==========================================================
-
-    # ¡ª¡ª¡ª ±£´æ PNG ¡ª¡ª¡ª
+    # ---------------- Save mask itself ----------------
     if save_png:
         png_path = os.path.join(out_dir, f"{name}_{file_idx_str}.png")
         plt.figure()
@@ -80,46 +97,41 @@ def save_full_mask(
         plt.title(name)
         plt.savefig(png_path, dpi=300)
         plt.close()
-        print(f"?? saved PNG -> {png_path}")
+        print(f"Saved PNG -> {png_path}")
 
-    # ¡ª¡ª¡ª ±£´æ TXT ¡ª¡ª¡ª
     if save_txt:
         txt_path = os.path.join(out_dir, f"{name}_{file_idx_str}.txt")
         np.savetxt(txt_path, mask_cpu.numpy(), fmt="%.3f")
-        print(f"?? saved TXT -> {txt_path}")
+        print(f"Saved TXT -> {txt_path}")
 
-    # ¡ª¡ª¡ª ±£´æ NPY ¡ª¡ª¡ª
     if save_npy:
         npy_path = os.path.join(out_dir, f"{name}_{file_idx_str}.npy")
         np.save(npy_path, mask_cpu.numpy())
-        print(f"?? saved NPY -> {npy_path}")
+        print(f"Saved NPY -> {npy_path}")
 
-    # ==========================================================
-    # ¢Ú ±£´æµþ¼ÓЧ¹û£¨²»ÏÔʾ£¬Ö»±£´æ£©
-    # ==========================================================
+    # ---------------- Save overlay if requested ----------------
     if show_on_image:
         if image is None:
-            print("? show_on_image=True µ«Î´´«Èë image£¬Ìø¹ýµþ¼Ó±£´æ")
+            print("show_on_image=True but image not provided, skipping overlay")
             return
 
-        # Èç¹ûÈ« 0£¬¾Í²»×öµþ¼Ó±£´æ
-        if is_all_zero(mask):
-            print(f"? {name} È« 0 ¡ú ²»Éú³Éµþ¼Ó overlay")
+        if all_zero and not force_save:
+            print(f"{name} all zeros ¡ú skip overlay")
             return
 
-        # image ¡ú numpy
-        if hasattr(image, "size"):     # PIL.Image
+        # Convert image to numpy array
+        if hasattr(image, "size"):  # PIL.Image
             img_np = np.array(image)
-        else:                           # numpy array
+        else:  # numpy array
             img_np = image
 
         mask_np = mask_cpu.numpy()
         if mask_np.dtype != np.uint8:
             mask_np = (mask_np > 0).astype(np.uint8)
 
-        # ºìÉ«°ë͸Ã÷µþ¼Ó
+        # Red overlay
         colored_mask = np.zeros((*mask_np.shape, 3), dtype=np.uint8)
-        colored_mask[..., 0] = mask_np * 255  # R channel
+        colored_mask[..., 0] = mask_np * 255  # Red channel
 
         overlay_path = os.path.join(out_dir, f"{name}_{file_idx_str}_overlay.png")
 
@@ -127,11 +139,7 @@ def save_full_mask(
         plt.imshow(img_np)
         plt.imshow(colored_mask, alpha=alpha)
         plt.title(f"{name} overlay")
-        plt.title(f"{name} overlay")
         plt.axis('off')
-
-        # ? ²»µ÷Óà plt.show()£¬Ö±½Ó±£´æ
         plt.savefig(overlay_path, dpi=300)
         plt.close()
-
-        print(f"?? saved overlay -> {overlay_path}")
+        print(f"Saved overlay -> {overlay_path}")