zhaoyinghan преди 1 месец
родител
ревизия
3939e25f18
променени са 4 файла, в които са добавени 127 реда и са изтрити 38 реда
  1. 6 0
      models/line_detect/heads/head_losses.py
  2. 16 8
      models/line_detect/line_dataset.py
  3. 46 30
      models/line_detect/loi_heads.py
  4. 59 0
      utils/data_process/mask/show_mask.py

+ 6 - 0
models/line_detect/heads/head_losses.py

@@ -5,6 +5,8 @@ import torch.nn.functional as F
 from torch import nn
 from torch.cuda import device
 
+from utils.data_process.mask.show_mask import save_full_mask
+
 
 class DiceLoss(nn.Module):
     def __init__(self, smooth=1.):
@@ -859,6 +861,9 @@ def compute_ins_loss(feature_logits, proposals, gt_, pos_matched_idxs):
         # line_loss = F.binary_cross_entropy_with_logits(line_logits, gs_heatmaps)
 
         # line_loss = F.cross_entropy(line_logits, gs_heatmaps)
+        save_full_mask(line_logits,"line_logits",out_dir=r"/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_loss")
+        save_full_mask(gs_heatmaps,"gs_heatmaps",out_dir=r"/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_loss")
+
         line_loss=combined_loss(line_logits, gs_heatmaps)
 
     else:
@@ -869,6 +874,7 @@ def compute_ins_loss(feature_logits, proposals, gt_, pos_matched_idxs):
     return line_loss
 
 
+
 def align_masks(keypoints, rois, heatmap_size):
     print(f'rois:{rois.shape}')
     print(f'heatmap_size:{heatmap_size}')

+ 16 - 8
models/line_detect/line_dataset.py

@@ -17,6 +17,7 @@ import torch
 
 import matplotlib.pyplot as plt
 from models.base.transforms import get_transforms
+from utils.data_process.mask.show_mask import save_full_mask
 from utils.data_process.show_prams import print_params
 
 
@@ -129,7 +130,7 @@ class LineDataset(BaseDataset):
                 arc7 = torch.cat([arc_param_i, arc_angle_i], dim=0)  # shape (7,)
 
 
-                print_params(arc7)
+                # print_params(arc7)
                 mask = arc_to_mask(arc7, shape, line_width=1)
 
                 arc_masks.append(mask)
@@ -137,7 +138,11 @@ class LineDataset(BaseDataset):
                 # arc_masks.append(arc_to_mask(arc7, shape, line_width=1))
 
             # print(f'circle_masks:{torch.stack(arc_masks, dim=0).shape}')
+
             target['circle_masks'] = torch.stack(arc_masks, dim=0)
+            save_full_mask(target['circle_masks'], "arc_masks",
+                           "/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_dataset")
+
 
 
 
@@ -158,6 +163,8 @@ class LineDataset(BaseDataset):
         # validate_keypoints(lines, shape[0], shape[1])
         return target
 
+
+
     def show(self, idx, show_type='all'):
         image, target = self.__getitem__(idx)
 
@@ -170,17 +177,18 @@ class LineDataset(BaseDataset):
         # print(f'boxes:{target["boxes"]}')
         img = image
 
-        if show_type == 'arc_masks':
+        if show_type == 'circle_masks':
             boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
                                               colors="yellow", width=1)
             # arc = target['arc']
-            arc_mask = target['arc_masks']
+            arc_mask = target['circle_masks']
             # print(f'taget circle:{arc.shape}')
             print(f'target circle_masks:{arc_mask.shape}')
-            plt.imshow(arc_mask.squeeze(0))
+            combined = torch.cat(list(arc_mask), dim=1)
+            plt.imshow(combined)
             plt.show()
 
-        if show_type == 'circle_masks':
+        if show_type == 'circle_masks11':
             boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
                                               colors="yellow", width=1)
             circle = target['circles']
@@ -267,7 +275,7 @@ def arc_to_mask(arc7, shape, line_width=1):
     """
     # print_params(arc7)
     # 确保 phi1 -> phi2 是正向(可处理跨 2π 的情况)
-    if torch.all(torch.tensor(arc7) == 0):
+    if torch.all(arc7 == 0):
         return torch.zeros(shape, dtype=torch.uint8)
 
     xc, yc, a, b, theta, phi1, phi2 = arc7
@@ -554,6 +562,6 @@ def get_boxes_lines(objs, shape):
 
 
 if __name__ == '__main__':
-    path = r'\\192.168.50.222/share/lm/1112/a_dataset'
+    path = r'/data/share/zyh/master_dataset/pokou/251115/a_dataset_pokou_mask'
     dataset = LineDataset(dataset_path=path, dataset_type='train', augmentation=False, data_type='jpg')
-    dataset.show(9, show_type='arc_masks')
+    dataset.show(9, show_type='circle_masks')

+ 46 - 30
models/line_detect/loi_heads.py

@@ -1251,13 +1251,7 @@ class RoIHeads(nn.Module):
                     h, w = targets[0]["img_size"]
                     img_size = h
 
-                    # gt_arcs_tensor = torch.zeros(0, 0)
-                    # if len(gt_arcs) > 0:
-                    #     gt_arcs_tensor = torch.cat(gt_arcs)
-                    #     print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
-                    #
-                    # if gt_arcs_tensor.shape[0] > 0:
-                    #     print(f'start to compute point_loss')
+
                     if len(gt_arcs) > 0 and feature_logits is not None:
                         loss_arc = compute_ins_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
 
@@ -1277,13 +1271,7 @@ class RoIHeads(nn.Module):
                         h, w = targets[0]["img_size"]
                         img_size = h
 
-                        # gt_arcs_tensor = torch.zeros(0, 0)
-                        # if len(gt_arcs) > 0:
-                        #     gt_arcs_tensor = torch.cat(gt_arcs)
-                        #     print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
 
-                        # if gt_arcs_tensor.shape[0] > 0 and feature_logits is not None:
-                        #     print(f'start to compute arc_loss')
 
                         if len(gt_arcs) > 0 and feature_logits is not None:
                             print(f'start to compute arc_loss')
@@ -1341,9 +1329,26 @@ class RoIHeads(nn.Module):
                 if matched_idxs is None:
                     raise ValueError("if in trainning, matched_idxs should not be None")
                 for img_id in range(num_images):
+
+                    # circle_pos = torch.where(labels[img_id] == 4)[0]
+                    # ins_proposals.append(proposals[img_id][circle_pos])
+                    # ins_pos_matched_idxs.append(matched_idxs[img_id][circle_pos])
                     circle_pos = torch.where(labels[img_id] == 4)[0]
-                    ins_proposals.append(proposals[img_id][circle_pos])
-                    ins_pos_matched_idxs.append(matched_idxs[img_id][circle_pos])
+                    circle_pos = circle_pos.flatten()
+                    idxs = circle_pos.detach().cpu().tolist()
+                    num_prop = len(proposals[img_id])
+                    for idx in idxs:
+                        if idx < 0 or idx >= num_prop:
+                            raise RuntimeError(
+                                f"Index out of bounds: circle_pos={idx}, but proposals len={num_prop}, "
+                                f"img_id={img_id}"
+                            )
+                    ins_proposals.append(
+                        proposals[img_id][idxs]
+                    )
+                    ins_pos_matched_idxs.append(
+                        matched_idxs[img_id][idxs]
+                    )
             else:
                 if targets is not None:
 
@@ -1356,9 +1361,25 @@ class RoIHeads(nn.Module):
                         raise ValueError("if in trainning, matched_idxs should not be None")
 
                     for img_id in range(num_images):
+                        # circle_pos = torch.where(labels[img_id] == 4)[0]
+                        # ins_proposals.append(proposals[img_id][circle_pos])
+                        # ins_pos_matched_idxs.append(matched_idxs[img_id][circle_pos])
                         circle_pos = torch.where(labels[img_id] == 4)[0]
-                        ins_proposals.append(proposals[img_id][circle_pos])
-                        ins_pos_matched_idxs.append(matched_idxs[img_id][circle_pos])
+                        circle_pos = circle_pos.flatten()
+                        idxs = circle_pos.detach().cpu().tolist()
+                        num_prop = len(proposals[img_id])
+                        for idx in idxs:
+                            if idx < 0 or idx >= num_prop:
+                                raise RuntimeError(
+                                    f"Index out of bounds: circle_pos={idx}, but proposals len={num_prop}, "
+                                    f"img_id={img_id}"
+                                )
+                        ins_proposals.append(
+                            proposals[img_id][idxs]
+                        )
+                        ins_pos_matched_idxs.append(
+                            matched_idxs[img_id][idxs]
+                        )
 
                 else:
                     pos_matched_idxs = None
@@ -1375,7 +1396,7 @@ class RoIHeads(nn.Module):
                 print(f'features from backbone:{features['0'].shape}')
                 feature_logits = self.ins_forward1(features, image_shapes, ins_proposals)
 
-                arc_equation = self.arc_equation_head(feature_logits)  # [proposal和,7]
+                # arc_equation = self.arc_equation_head(feature_logits)  # [proposal和,7]
 
                 loss_ins = None
                 loss_ins_extra=None
@@ -1409,12 +1430,7 @@ class RoIHeads(nn.Module):
                         print(f'start to compute circle_loss')
 
                         loss_ins = compute_ins_loss(feature_logits, ins_proposals, gt_inses,ins_pos_matched_idxs)
-                        total_loss, loss_arc_equation, loss_arc_ends = compute_arc_equation_loss(arc_equation,
-                                                                                                 ins_proposals,
-                                                                                                 gt_mask_ends,
-                                                                                                 gt_mask_params,
-                                                                                                 ins_pos_matched_idxs,
-                                                                                                 labels)
+                        # total_loss, loss_arc_equation, loss_arc_ends = compute_arc_equation_loss(arc_equation,ins_proposals,gt_mask_params,ins_pos_matched_idxs,labels)
                         loss_arc_ends = loss_arc_ends
                     if loss_arc_equation is None:
                         print(f'loss_arc_equation is None')
@@ -1456,8 +1472,8 @@ class RoIHeads(nn.Module):
 
                             loss_ins = compute_ins_loss(feature_logits, ins_proposals, gt_inses,
                                                            ins_pos_matched_idxs)
-                            total_loss, loss_arc_equation, loss_arc_ends = compute_arc_equation_loss(
-                                arc_equation, ins_proposals, gt_mask_ends, gt_mask_params, ins_pos_matched_idxs, labels)
+                            # total_loss, loss_arc_equation, loss_arc_ends = compute_arc_equation_loss(arc_equation,ins_proposals,gt_mask_params,ins_pos_matched_idxs,labels)
+
                             loss_arc_ends = loss_arc_ends
 
                             # loss_ins_extra = compute_circle_extra_losses(feature_logits, circle_proposals, gt_circles,circle_pos_matched_idxs)
@@ -1508,10 +1524,10 @@ class RoIHeads(nn.Module):
                             ins_masks, ins_scores, circle_points = ins_inference(feature_logits,
                                                                                          ins_proposals, th=0)
 
-                            arc7, arc_scores = arc_inference1(arc_equation, feature_logits, ins_proposals, 0.5)
-                            for arc_, arc_score, r in zip(arc7, arc_scores, result):
-                                r["arcs"] = arc_
-                                r["arc_scores"] = arc_score
+                            # arc7, arc_scores = arc_inference1(arc_equation, feature_logits, ins_proposals, 0.5)
+                            # for arc_, arc_score, r in zip(arc7, arc_scores, result):
+                            #     r["arcs"] = arc_
+                            #     r["arc_scores"] = arc_score
                             # print(f'circles_probs:{circles_probs.shape}, circles_scores:{circles_scores.shape}')
                             proposals_per_image = [box.size(0) for box in ins_proposals]
                             print(f'ins_proposals_per_image:{proposals_per_image}')

+ 59 - 0
utils/data_process/mask/show_mask.py

@@ -0,0 +1,59 @@
+import os
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+from glob import glob
+
+def save_full_mask(mask: torch.Tensor, name="mask", out_dir="output",
+                   save_png=True, save_txt=True, save_npy=False,
+                   max_files=100,
+                   save_all_zero=False):
+    """Debug helper: dump one 2-D slice of ``mask`` into ``out_dir``.
+
+    A 4-D input is reduced to ``mask[0, 0]`` and a 3-D input to ``mask[0]``;
+    any other rank is used as given.  Files are named
+    ``{name}_{idx:03d}.png`` / ``.txt`` / ``.npy`` and ``idx`` cycles through
+    ``0 .. max_files - 1``, continuing after the most recently written dump
+    found in ``out_dir``, so older dumps are overwritten once the ring is full.
+
+    Args:
+        mask: Tensor to dump.
+        name: File-name prefix, also used as the plot title.
+        out_dir: Target directory; created if missing.
+        save_png: Save a grayscale matplotlib rendering with a colorbar.
+        save_txt: Save the values with ``np.savetxt`` (format ``%.3f``).
+        save_npy: Save the raw array with ``np.save``.
+        max_files: Number of index slots before wrapping (values below 1
+            are treated as 1).
+        save_all_zero: When False, an all-zero slice is skipped.
+
+    Returns:
+        None.
+    """
+    # An empty leading dimension would make the [0] indexing below raise.
+    if mask.numel() == 0:
+        print(f"{name} is empty, skipped saving")
+        return
+
+    if mask.dim() == 4:
+        mask = mask[0, 0]
+    elif mask.dim() == 3:
+        mask = mask[0]
+
+    if not save_all_zero and torch.all(mask == 0):
+        # Message restored from mojibake (it appears to have been GBK bytes
+        # rendered as Latin-1); it reads "all 0, skipped saving".
+        print(f"? {name} 全 0，已跳过保存")
+        return
+
+    os.makedirs(out_dir, exist_ok=True)
+    mask_cpu = mask.detach().cpu()
+
+    print(f"\nSaving full mask: {name}")
+    print(f"shape = {tuple(mask.shape)}, dtype = {mask.dtype}, device = {mask.device}")
+
+    # Locate the newest dump of exactly this name, in any of the three formats.
+    # Ordering by modification time (not by the largest index) keeps the ring
+    # advancing after it wraps; the strict "<name>_<digits>" match ignores
+    # stray files and dumps that belong to a longer name sharing this prefix.
+    prefix = f"{name}_"
+    newest = None  # (mtime in ns, index) of the most recent matching file
+    for fname in os.listdir(out_dir):
+        stem, ext = os.path.splitext(fname)
+        digits = stem[len(prefix):]
+        if not (stem.startswith(prefix) and digits.isdecimal()):
+            continue
+        if ext not in (".png", ".txt", ".npy"):
+            continue
+        stamp = (os.stat(os.path.join(out_dir, fname)).st_mtime_ns, int(digits))
+        if newest is None or stamp > newest:
+            newest = stamp
+
+    if newest is None:
+        next_idx = 0
+    else:
+        # max(..., 1) avoids a modulo-by-zero when max_files is 0 or negative.
+        next_idx = (newest[1] + 1) % max(max_files, 1)
+
+    file_idx_str = f"{next_idx:03d}"
+
+    if save_png:
+        png_path = os.path.join(out_dir, f"{name}_{file_idx_str}.png")
+        fig = plt.figure()
+        try:
+            plt.imshow(mask_cpu, cmap="gray")
+            plt.colorbar()
+            plt.title(name)
+            plt.savefig(png_path, dpi=300)
+        finally:
+            # Close even when plotting/saving fails so figures do not pile up.
+            plt.close(fig)
+        print(f"? saved PNG -> {png_path}")
+
+    if save_txt:
+        txt_path = os.path.join(out_dir, f"{name}_{file_idx_str}.txt")
+        np.savetxt(txt_path, mask_cpu.numpy(), fmt="%.3f")
+        print(f"? saved TXT -> {txt_path}")
+
+    if save_npy:
+        npy_path = os.path.join(out_dir, f"{name}_{file_idx_str}.npy")
+        np.save(npy_path, mask_cpu.numpy())
+        print(f"? saved NPY -> {npy_path}")