Преглед на файлове

arc_inference提取弧线点坐标,并绘制在tensorboard

xue50 преди 4 месеца
родител
ревизия
cf0d88207b
променени са 4 файла, в които са добавени 66 реда и са изтрити 9 реда
  1. 36 2
      models/line_detect/heads/head_losses.py
  2. 4 2
      models/line_detect/loi_heads.py
  3. 2 2
      models/line_detect/train.yaml
  4. 24 3
      models/line_detect/trainer.py

+ 36 - 2
models/line_detect/heads/head_losses.py

@@ -789,7 +789,9 @@ def line_inference(x, line_boxes):
 
     return lines_probs, lines_scores
 
-def arc_inference(x, arc_boxes):
+
+
+def arc_inference(x, arc_boxes,th):
     # type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
 
     points_probs = []
@@ -810,7 +812,39 @@ def arc_inference(x, arc_boxes):
 
     points_probs_tensor=torch.cat(points_probs)
     print(f'points_probs shape:{points_probs_tensor.shape}')
-    return points_probs,points_scores
+
+    feature_logits = x
+    batch_size = feature_logits.shape[0]
+    num_proposals = len(arc_boxes[0])
+
+    results = [[torch.empty(0, 2) for _ in range(num_proposals)] for _ in range(batch_size)]
+    proposals_list = arc_boxes[0]  # [[tensor(...)]]
+
+    for proposal_idx, proposal in enumerate(proposals_list):
+        coords = proposal.tolist()
+        x1, y1, x2, y2 = map(int, coords)
+
+        x1 = max(0, x1)
+        y1 = max(0, y1)
+        x2 = min(feature_logits.shape[3], x2)
+        y2 = min(feature_logits.shape[2], y2)
+
+        for batch_idx in range(batch_size):
+            region = feature_logits[batch_idx, :, y1:y2, x1:x2]
+            mask = region > th
+            coords = torch.nonzero(mask)
+
+            if coords.numel() > 0:
+                # 取 (y, x),然后转换为全局坐标 (x, y)
+                local_coords = coords[:, [2, 1]]  # (x, y)
+                local_coords[:, 0] += x1
+                local_coords[:, 1] += y1
+
+                results[batch_idx][proposal_idx] = local_coords
+
+    print(f're:{results}')
+
+    return points_probs,points_scores,results
 
 
 import torch.nn.functional as F

+ 4 - 2
models/line_detect/loi_heads.py

@@ -1272,11 +1272,13 @@ class RoIHeads(nn.Module):
 
                     if feature_logits is not None:
 
-                        arcs_probs, arcs_scores = arc_inference(feature_logits,arc_proposals)
-                        for keypoint_prob, kps, r in zip(arcs_probs, arcs_scores, result):
+                        arcs_probs, arcs_scores, arcs_point = arc_inference(feature_logits,arc_proposals, th=0)
+                        for keypoint_prob, kps, kp, r in zip(arcs_probs, arcs_scores, arcs_point, result):
                             # r["arcs"] = keypoint_prob
                             r["arcs"] = feature_logits
                             r["arcs_scores"] = kps
+                            r["arcs_point"] = kp
+
 
             # print(f'loss_point:{loss_point}')
             losses.update(loss_arc)

+ 2 - 2
models/line_detect/train.yaml

@@ -1,7 +1,7 @@
 io:
   logdir: train_results
 
-  datadir: /data/share/zyh/arc/a_datasetb
+  datadir: \\192.168.50.222\share\rlq\datasets\arc_datasets_100
 #  datadir: /data/share/zjh/Dataset_correct_xanylabel_tiff
 
 #  datadir: /data/share/rlq/datasets/250718caisegangban
@@ -20,7 +20,7 @@ io:
 train_params:
   resume_from:
   num_workers: 8
-  batch_size: 2
+  batch_size: 1
   max_epoch: 8000000
 #  augmentation: True
   augmentation: False

+ 24 - 3
models/line_detect/trainer.py

@@ -226,15 +226,32 @@ class Trainer(BaseTrainer):
         if 'arcs' in result:
             arcs = result['arcs'][0]
 
-
             # img_rgb = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
 
-
-
             # img_tensor =torch.tensor(img_rgb)
             # img_tensor = np.transpose(img_tensor)
             self.writer.add_image('z-out-arc', arcs, global_step=epoch)
 
+            aa = result['arcs_point'][0]
+
+            x_coords = aa[:, 0].cpu()/800*2000
+            y_coords = aa[:, 1].cpu()/800*2000
+
+            plt.figure(figsize=(10, 8))
+            plt.imshow(im)
+            plt.scatter(x_coords, y_coords, c='red', s=0.3, label='Arc Points')
+            plt.title("Image with Arc Points")
+            plt.legend()
+            plt.axis('off')
+
+            fig = plt.gcf()
+            fig.canvas.draw()
+            image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+            image_from_plot = image_from_plot.reshape(fig.canvas.get_width_height()[::-1] + (3,))  # H x W x 3
+            plt.close()
+
+            self.writer.add_image('z-out-result', image_from_plot, dataformats='HWC')
+
             # cv2.imshow('arc', img_rgb)
             # cv2.waitKey(1000000)
 
@@ -290,6 +307,10 @@ class Trainer(BaseTrainer):
             weight_decay=kwargs['train_params']['optim']['weight_decay'],
 
         )
+
+        model, optimizer = self.load_best_model(model, optimizer,
+                                                r"\\192.168.50.222\share\rlq\weights\250725_arc_res152_best_val.pth",
+                                                device)
         # scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
         scheduler = ReduceLROnPlateau(optimizer, 'min', patience=30)