Ver Fonte

添加绘制arc功能

lstrlq há 4 meses atrás
pai
commit
dada562db1

+ 3 - 0
models/line_detect/heads/head_losses.py

@@ -768,6 +768,9 @@ def arc_inference(x, arc_boxes):
         points_probs.append(point_prob.unsqueeze(1))
         points_scores.append(point_scores)
 
+
+    points_probs_tensor=torch.cat(points_probs)
+    print(f'points_probs shape:{points_probs_tensor.shape}')
     return points_probs,points_scores
 
 

+ 1 - 1
models/line_detect/line_detect.py

@@ -458,7 +458,7 @@ def linedetect_newresnet152fpn(
     if num_points is None:
         num_points = 3
 
-    size=1024
+    size=800
     backbone =resnet101fpn(out_channels=256)
     featmap_names=['0', '1', '2', '3','4','pool']
     # print(f'featmap_names:{featmap_names}')

+ 9 - 5
models/line_detect/loi_heads.py

@@ -1239,15 +1239,19 @@ class RoIHeads(nn.Module):
                     h, w = targets[0]["img_size"]
                     img_size = h
 
-                    gt_arcs_tensor = torch.zeros(0, 0)
-                    if len(gt_arcs) > 0:
-                        gt_arcs_tensor = torch.cat(gt_arcs)
-                        print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
+                    # gt_arcs_tensor = torch.zeros(0, 0)
+                    # if len(gt_arcs) > 0:
+                    #     gt_arcs_tensor = torch.cat(gt_arcs)
+                    #     print(f'gt_arcs_tensor:{gt_arcs_tensor.shape}')
 
-                    if gt_arcs_tensor.shape[0] > 0 and feature_logits is not None:
+                    # if gt_arcs_tensor.shape[0] > 0 and feature_logits is not None:
+                    #     print(f'start to compute arc_loss')
+
+                    if len(gt_arcs) > 0 and feature_logits is not None:
                         print(f'start to compute arc_loss')
                         loss_arc = compute_arc_loss(feature_logits, arc_proposals, gt_arcs, arc_pos_matched_idxs)
 
+
                     if loss_arc is None:
                         print(f'loss_arc is None111')
                         loss_arc = torch.tensor(0.0, device=device)

+ 2 - 2
models/line_detect/train.yaml

@@ -1,13 +1,13 @@
 io:
   logdir: train_results
 
-#  datadir: /data/zyh/py_ws/code/a_dataset
+  datadir: /data/share/rlq/datasets/arc_datasets
 #  datadir: /data/share/zjh/Dataset_correct_xanylabel_tiff
 
 #  datadir: /data/share/rlq/datasets/250718caisegangban
 
 #  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
-  datadir: \\192.168.50.222/share/zyh/arc/a_dataset
+#  datadir: \\192.168.50.222/share/zyh/arc/a_dataset
 #  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2
 #  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
   data_type: rgb

+ 2 - 2
models/line_detect/train_demo.py

@@ -17,10 +17,10 @@ if __name__ == '__main__':
     # model = lineDetect_resnet18_fpn()
 
     # model=linedetect_resnet18_fpn()
-    model=linedetect_newresnet18fpn(num_points=3)
+    # model=linedetect_newresnet18fpn(num_points=3)
     # model=linedetect_newresnet50fpn(num_points=3)
     # model = linedetect_newresnet101fpn(num_points=3)
-    # model = linedetect_newresnet152fpn(num_points=3)
+    model = linedetect_newresnet152fpn(num_points=3)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
     # model=linedetect_maxvitfpn()
     # model=linedetect_high_maxvitfpn()

+ 24 - 0
models/line_detect/trainer.py

@@ -2,6 +2,7 @@ import os
 import time
 from datetime import datetime
 
+import cv2
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
@@ -222,6 +223,29 @@ class Trainer(BaseTrainer):
 
             self.writer.add_image("z-output_line", line_image.permute(1, 2, 0), epoch, dataformats="HWC")
 
+        if 'arcs' in result:
+            arcs = result['arcs']
+            img_np = img.numpy()
+            img_np=img_np.transpose(1,2,0)
+            # cv2.imshow('original', img_np*255)
+            # cv2.waitKey(100000)
+            for arc in arcs:
+                print(f'arc len:{len(arc)}')
+                for i in range(1, len(arc)):
+                    pt1 = (int(arc[i - 1][0]), int(arc[i - 1][1]))
+                    pt2 = (int(arc[i][0]), int(arc[i][1]))
+                    cv2.line(img_np, pt1, pt2, color=(255, 0, 0), thickness=2)
+
+            img_rgb = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
+
+
+
+            img_tensor =torch.tensor(img_rgb)
+            img_tensor = np.transpose(img_tensor)
+            self.writer.add_image('z-out-arc', img_tensor, global_step=epoch)
+
+            # cv2.imshow('arc', img_rgb)
+            # cv2.waitKey(1000000)