Pārlūkot izejas kodu

save local changes

zhaoyinghan 1 mēnesi atpakaļ
vecāks
revīzija
2f83343537

+ 1 - 1
models/base/base_dataset.py

@@ -16,7 +16,7 @@ class BaseDataset(Dataset, ABC):
         pass
 
     @abstractmethod
-    def read_target(self,item,lbl_path,extra=None):
+    def read_target(self,item,lbl_path,extra=None,image=None):
         pass
 
     """显示数据集指定图片"""

+ 2 - 2
models/line_detect/heads/head_losses.py

@@ -861,8 +861,8 @@ def compute_ins_loss(feature_logits, proposals, gt_, pos_matched_idxs):
         # line_loss = F.binary_cross_entropy_with_logits(line_logits, gs_heatmaps)
 
         # line_loss = F.cross_entropy(line_logits, gs_heatmaps)
-        save_full_mask(line_logits,"line_logits",out_dir=r"/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_loss")
-        save_full_mask(gs_heatmaps,"gs_heatmaps",out_dir=r"/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_loss")
+        # save_full_mask(line_logits,"line_logits",out_dir=r"/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_loss")
+        # save_full_mask(gs_heatmaps,"gs_heatmaps",out_dir=r"/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_loss")
 
         line_loss=combined_loss(line_logits, gs_heatmaps)
 

+ 15 - 10
models/line_detect/line_dataset.py

@@ -71,7 +71,7 @@ class LineDataset(BaseDataset):
             img = PIL.Image.open(img_path).convert('RGB')
             w, h = img.size
         # wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
-        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
+        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w),image=img)
 
         self.transforms = get_transforms(augmention=self.augmentation)
 
@@ -82,7 +82,7 @@ class LineDataset(BaseDataset):
     def __len__(self):
         return len(self.imgs)
 
-    def read_target(self, item, lbl_path, shape, extra=None):
+    def read_target(self, item, lbl_path, shape,extra=None,image=None):
         # print(f'shape:{shape}')
         # print(f'lbl_path:{lbl_path}')
         with open(lbl_path, 'r') as file:
@@ -122,26 +122,30 @@ class LineDataset(BaseDataset):
             # print_params(arc_angles)
             arc_masks = []
 
-
-
             for i in range(len(arc_params)):
                 arc_param_i = arc_params[i].view(-1)  # shape (5,)
                 arc_angle_i = arc_angles[i].view(-1)  # shape (2,)
                 arc7 = torch.cat([arc_param_i, arc_angle_i], dim=0)  # shape (7,)
 
 
-                # print_params(arc7)
                 mask = arc_to_mask(arc7, shape, line_width=1)
 
                 arc_masks.append(mask)
-                # arc7=arc_params[i] + arc_angles[i].tolist()
-                # arc_masks.append(arc_to_mask(arc7, shape, line_width=1))
 
-            # print(f'circle_masks:{torch.stack(arc_masks, dim=0).shape}')
+            print_params(arc_masks,image)
 
             target['circle_masks'] = torch.stack(arc_masks, dim=0)
-            save_full_mask(target['circle_masks'], "arc_masks",
-                           "/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_dataset")
+
+            # for i, m in enumerate(target['circle_masks']):
+            #     save_full_mask(
+            #         m,
+            #         name=f"arc_mask_{i}",
+            #         out_dir=r"/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_dataset",
+            #         save_png=True,
+            #         save_npy=True,
+            #         image=image,
+            #         show_on_image=True
+            #     )
 
 
 
@@ -277,6 +281,7 @@ def arc_to_mask(arc7, shape, line_width=1):
     # 确保 phi1 -> phi2 是正向(可处理跨 2π 的情况)
     if torch.all(arc7 == 0):
         return torch.zeros(shape, dtype=torch.uint8)
+    print_params(arc7)
 
     xc, yc, a, b, theta, phi1, phi2 = arc7
     H, W = shape

+ 55 - 0
utils/data_process/mask/arc7tomask.py

@@ -0,0 +1,55 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+def draw_ellipse_arc(xc, yc, a, b, theta, phi1, phi2, num_points=200):
+    """
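+    Draw an elliptical arc.
+
+    Args:
+        xc, yc (float): Center coordinates of the ellipse.
+        a, b (float): Semi-major and semi-minor axes (a >= b).
+        theta (float): Counter-clockwise rotation angle of the ellipse (radians).
+        phi1, phi2 (float): Start and end angles of the arc (radians), in [0, 2π).
+        num_points (int): Drawing precision (number of sampled points).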
+    """
+
+    # Generate the arc parameter values
+    phi = np.linspace(phi1, phi2, num_points)
+
+    # Points on the ellipse given by its standard equation
+    x = a * np.cos(phi)
+    y = b * np.sin(phi)
+
+    # Rotation matrix
+    cos_t = np.cos(theta)
+    sin_t = np.sin(theta)
+
+    xr = x * cos_t - y * sin_t + xc
+    yr = x * sin_t + y * cos_t + yc
+
+    # Plot
+    plt.figure(figsize=(6,6))
+    plt.plot(xr, yr, 'r-', linewidth=2)
+    plt.scatter([xc], [yc], color='blue', label='Center')
+
+    plt.axis('equal')
+    plt.title("Ellipse Arc")
+    plt.grid(True)
+    plt.show()
+
+
+if __name__ == "__main__":
+    """
+    Tensor(shape=(7,), dtype=torch.float32, device=cpu, values=[1025.0, 560.0, 131.87399291992188, 116.3759994506836, -0.5879999995231628, 2.1936283111572266, 3.825697898864746])Tensor(shape=(7,), 
+    dtype=torch.float32, device=cpu, values=[721.0, 527.0, 121.36299896240234, 134.70399475097656, -0.02500000037252903, 2.0525598526000977, 4.132845878601074])
+    111values=[1022.0, 443.0, 88.08300018310547, 118.55699920654297, 2.7829999923706055, 5.329736709594727, 0.142518550157547])
+    values=[1123.0, 923.0, 158.6739959716797, 182.2220001220703, 0.6779999732971191, 0.8875003457069397, 2.6391656398773193])
+    Tensor(shape=(7,), dtype=torch.float32, device=cpu, values=[618.0, 542.0, 128.01300048828125, 141.08799743652344, 0.37400001287460327, 1.150931477546692, 5.289559364318848])\
+    Tensor(shape=(7,), dtype=torch.float32, device=cpu, values=[850.0, 365.0, 159.79400634765625, 121.38099670410156, -3.078000068664551, 4.8650922775268555, 5.932548522949219])
+    """
+    # Example: values taken from a tensor
+    xc, yc, a, b, theta, phi1, phi2 = [
+        1022.0, 443.0, 88.08300018310547, 118.55699920654297, 2.7829999923706055, 5.329736709594727, 0.142518550157547
+    ]
+
+    draw_ellipse_arc(xc, yc, a, b, theta, phi1, phi2)

+ 6 - 0
utils/data_process/mask/npy.py

@@ -0,0 +1,6 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+arr = np.load(r"/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_dataset/arc_masks_5_000.npy")
+plt.imshow(arr)
+plt.show()

+ 94 - 16
utils/data_process/mask/show_mask.py

@@ -4,32 +4,61 @@ import numpy as np
 import matplotlib.pyplot as plt
 from glob import glob
 
-def save_full_mask(mask: torch.Tensor, name="mask", out_dir="output",
-                   save_png=True, save_txt=True, save_npy=False,
-                   max_files=100,
-                   save_all_zero=False):
-
+def save_full_mask(
+    mask: torch.Tensor,
+    name="mask",
+    out_dir="output",
+    save_png=True,
+    save_txt=True,
+    save_npy=False,
+    max_files=100,
+    save_all_zero=False,
+    zero_eps=1e-6,
+    # Newly added parameters
+    image=None,            # PIL.Image or numpy(H,W,3)
+    show_on_image=False,   # whether to save the overlaid image
+    alpha=0.4              # mask transparency
+):
+    """
+    Save the full feature map; optionally also generate a visualization file overlaid on the original image.
+    """
 
+    # Strip the extra batch/channel dimensions
     if mask.dim() == 4:
         mask = mask[0, 0]
     elif mask.dim() == 3:
         mask = mask[0]
 
-    if not save_all_zero and torch.all(mask == 0):
-        print(f"? {name} È« 0£¬ÒÑÌø¹ý±£´æ")
-        return
+    # ---------------- All-zero check helper ----------------
+    def is_all_zero(t: torch.Tensor):
+        if t.dtype.is_floating_point:
+            return torch.all(torch.abs(t) < zero_eps)
+        else:
+            return torch.all(t == 0)
+
+    # Handle the all-zero case
+    if is_all_zero(mask):
+        print(f"? {name} È« 0")
+        if not save_all_zero:
+            print(f"¡ú δ±£´æ£¨save_all_zero=False£©")
+            return
+        else:
+            print(f"¡ú °´ÉèÖüÌÐø±£´æ£¨save_all_zero=True£©")
 
     os.makedirs(out_dir, exist_ok=True)
     mask_cpu = mask.detach().cpu()
 
-    print(f"\nSaving full mask: {name}")
-    print(f"shape = {tuple(mask.shape)}, dtype = {mask.dtype}, device = {mask.device}")
-
+    # ------------------- Assign the sequence index -------------------
     pattern = os.path.join(out_dir, f"{name}_*.png")
     existing_files = sorted(glob(pattern))
     if existing_files:
-        existing_idxs = [int(os.path.basename(f).split("_")[-1].split(".")[0]) for f in existing_files]
-        next_idx = max(existing_idxs) + 1
+        existing_idx = []
+        for f in existing_files:
+            try:
+                existing_idx.append(int(os.path.basename(f).split("_")[-1].split(".")[0]))
+            except:
+                pass
+        next_idx = max(existing_idx) + 1
     else:
         next_idx = 0
 
@@ -38,6 +67,11 @@ def save_full_mask(mask: torch.Tensor, name="mask", out_dir="output",
 
     file_idx_str = f"{next_idx:03d}"
 
+    # ==========================================================
+    # (1) Save the mask itself
+    # ==========================================================
+
+    # --- Save PNG ---
     if save_png:
         png_path = os.path.join(out_dir, f"{name}_{file_idx_str}.png")
         plt.figure()
@@ -46,14 +80,58 @@ def save_full_mask(mask: torch.Tensor, name="mask", out_dir="output",
         plt.title(name)
         plt.savefig(png_path, dpi=300)
         plt.close()
-        print(f"? saved PNG -> {png_path}")
+        print(f"?? saved PNG -> {png_path}")
 
+    # --- Save TXT ---
     if save_txt:
         txt_path = os.path.join(out_dir, f"{name}_{file_idx_str}.txt")
         np.savetxt(txt_path, mask_cpu.numpy(), fmt="%.3f")
-        print(f"? saved TXT -> {txt_path}")
+        print(f"?? saved TXT -> {txt_path}")
 
+    # --- Save NPY ---
     if save_npy:
         npy_path = os.path.join(out_dir, f"{name}_{file_idx_str}.npy")
         np.save(npy_path, mask_cpu.numpy())
-        print(f"? saved NPY -> {npy_path}")
+        print(f"?? saved NPY -> {npy_path}")
+
+    # ==========================================================
+    # (2) Save the overlay result (saved only, not displayed)
+    # ==========================================================
+    if show_on_image:
+        if image is None:
+            print("? show_on_image=True µ«Î´´«Èë image£¬Ìø¹ýµþ¼Ó±£´æ")
+            return
+
+        # If the mask is all zero, skip saving the overlay
+        if is_all_zero(mask):
+            print(f"? {name} È« 0 ¡ú ²»Éú³Éµþ¼Ó overlay")
+            return
+
+        # image -> numpy
+        if hasattr(image, "size"):     # PIL.Image
+            img_np = np.array(image)
+        else:                           # numpy array
+            img_np = image
+
+        mask_np = mask_cpu.numpy()
+        if mask_np.dtype != np.uint8:
+            mask_np = (mask_np > 0).astype(np.uint8)
+
+        # Semi-transparent red overlay
+        colored_mask = np.zeros((*mask_np.shape, 3), dtype=np.uint8)
+        colored_mask[..., 0] = mask_np * 255  # R channel
+
+        overlay_path = os.path.join(out_dir, f"{name}_{file_idx_str}_overlay.png")
+
+        plt.figure(figsize=(8, 8))
+        plt.imshow(img_np)
+        plt.imshow(colored_mask, alpha=alpha)
+        plt.title(f"{name} overlay")
+        plt.title(f"{name} overlay")
+        plt.axis('off')
+
+        # Do not call plt.show(); save directly
+        plt.savefig(overlay_path, dpi=300)
+        plt.close()
+
+        print(f"?? saved overlay -> {overlay_path}")