
Adjust the points structure of the inference output so that draw_keypoints can be called on it directly
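For context, torchvision's draw_keypoints expects keypoints shaped (num_instances, K, 2), so a flat (N, 2) tensor of predicted points needs a singleton keypoint dimension before it can be drawn. A minimal, self-contained sketch of that shape handling (the image and point values below are made up purely for illustration):

import torch
from torchvision.utils import draw_keypoints

# Dummy 3-channel uint8 image and N predicted points, one (x, y) pair each.
img = torch.zeros(3, 256, 256, dtype=torch.uint8)
points = torch.tensor([[40.0, 50.0], [120.0, 200.0]])   # shape (N, 2)

# draw_keypoints wants (num_instances, K, 2); unsqueeze(1) treats every
# point as an instance with a single keypoint: (N, 2) -> (N, 1, 2).
out = draw_keypoints(img, points.unsqueeze(1), colors="red", width=3)
print(out.shape)   # torch.Size([3, 256, 256])

This is the same reshaping the diff below applies, once via unsqueeze(1) at the call sites in line_dataset.py and once on the inference output in loi_heads.py.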

RenLiqiang, 5 months ago
commit b5faea6834
3 changed files with 34 additions and 17 deletions
  1. models/line_detect/line_dataset.py  (+14 -6)
  2. models/line_detect/loi_heads.py  (+3 -1)
  3. models/line_detect/trainer.py  (+17 -10)

models/line_detect/line_dataset.py  (+14 -6)

@@ -108,16 +108,23 @@ class LineDataset(BaseDataset):
         sm.set_array([])
 
         # img_path = os.path.join(self.img_path, self.imgs[idx])
+        # print(f'boxes:{target["boxes"]}')
         img = image
         if show_type=='all':
             boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
                                                   colors="yellow", width=1)
-            keypoint_img=draw_keypoints(boxed_image,target['lines'],colors='red',width=3)
+            keypoint_img=draw_keypoints(boxed_image,target['points'].unsqueeze(1),colors='red',width=3)
             plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
             plt.show()
 
-        if show_type=='lines':
-            keypoint_img=draw_keypoints((img * 255).to(torch.uint8),target['lines'],colors='red',width=3)
+        # if show_type=='lines':
+        #     keypoint_img=draw_keypoints((img * 255).to(torch.uint8),target['lines'],colors='red',width=3)
+        #     plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
+        #     plt.show()
+
+        if show_type=='points':
+            print(f"points:{target['points'].shape}")
+            keypoint_img=draw_keypoints((img * 255).to(torch.uint8),target['points'].unsqueeze(1),colors='red',width=3)
             plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
             plt.show()
 
@@ -164,7 +171,7 @@ def get_boxes_lines(objs,shape):
              xmin=max(0,p[0]-6)
              xmax = min(w, p[0] +6)
              ymin=max(0,p[1]-6)
-             ymax = max(h, p[1] + 6)
+             ymax = min(h, p[1] + 6)
 
              points.append(p)
              labels.append(torch.tensor(1))
@@ -177,6 +184,7 @@ def get_boxes_lines(objs,shape):
             labels.append(torch.tensor(3))
 
     boxes=torch.tensor(boxes)
+    print(f'boxes:{boxes.shape}')
     labels=torch.tensor(labels)
     points=torch.tensor(points)
     # print(f'read labels:{labels}')
@@ -187,6 +195,6 @@ def get_boxes_lines(objs,shape):
     return boxes,line_point_pairs,points,labels
 
 if __name__ == '__main__':
-    path=r"\\192.168.50.222/share/rlq/datasets/0706_"
-    dataset= LineDataset(dataset_path=path, dataset_type='train',augmentation=True, data_type='jpg')
+    path=r"\\192.168.50.222\share\rlq\datasets\Dataset0709_"
+    dataset= LineDataset(dataset_path=path, dataset_type='train',augmentation=False, data_type='jpg')
     dataset.show(1,show_type='all')
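As a side note on the get_boxes_lines hunk above: with the old max(h, p[1] + 6) the bottom edge of a point's box was never clipped to the image height. A minimal sketch of the clamping the fix restores, keeping the ±6 half-width from the diff; the point_to_box name, the (h, w) shape ordering, and the (xmin, ymin, xmax, ymax) box layout are assumptions for illustration, since the hunk does not show how the box is appended:

import torch

def point_to_box(p, shape, half=6):
    # Clamp a (2*half) x (2*half) box around point p = (x, y) to the image bounds.
    h, w = shape  # assumed (height, width) ordering
    xmin = max(0, p[0] - half)
    xmax = min(w, p[0] + half)
    ymin = max(0, p[1] - half)
    ymax = min(h, p[1] + half)   # was max(h, ...), which never clipped at the bottom
    return torch.tensor([xmin, ymin, xmax, ymax])

print(point_to_box((3, 510), (512, 512)))   # tensor([0, 504, 9, 512])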

models/line_detect/loi_heads.py  (+3 -1)

@@ -868,7 +868,7 @@ def line_inference(x, boxes):
         lines_probs.append(line_prob)
         lines_scores.append(line_scores)
 
-        points_probs.append(point_prob)
+        points_probs.append(point_prob.unsqueeze(1))
         points_scores.append(point_scores)
 
     return lines_probs, lines_scores,points_probs,points_scores
@@ -1568,6 +1568,7 @@ class RoIHeads(nn.Module):
                     loss_point={"loss_point":loss_point}
 
 
+
                 else:
                     if line_logits is None or line_proposals is None:
                         raise ValueError(
@@ -1577,6 +1578,7 @@ class RoIHeads(nn.Module):
                     lines_probs, lines_scores,point_probs,points_scores = line_inference(line_logits, line_proposals)
 
                     for keypoint_prob, kps, points,ps,r in zip(lines_probs, lines_scores,point_probs,points_scores, result):
+                        print(f'points_prob :{points.shape}')
                         r["lines"] = keypoint_prob
                         r["liness_scores"] = kps
                         r["points"] = points

models/line_detect/trainer.py  (+17 -10)

@@ -191,7 +191,7 @@ class Trainer(BaseTrainer):
 
 
 
-    def writer_predict_result(self, img, result, epoch):
+    def writer_predict_result(self, img, result, epoch,type=1):
         img = img.cpu().detach()
         im = img.permute(1, 2, 0)  # [512, 512, 3]
         self.writer.add_image("z-ori", im, epoch, dataformats="HWC")
@@ -203,19 +203,24 @@ class Trainer(BaseTrainer):
         # plt.show()
 
         self.writer.add_image("z-obj", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
-        keypoint_img = draw_keypoints(boxed_image, result['lines'], colors='red', width=3)
 
-        self.writer.add_image("z-output", keypoint_img, epoch)
-        print("lines shape:", result['lines'].shape)
 
-        # draw the line segments with my own function
-        # line_image = draw_lines(boxed_image, result['lines'], color='red', width=3)
-        print(f"shape of linescore:{result['liness_scores'].shape}")
-        scores = result['liness_scores'].mean(dim=1)  # shape: [31]
+        if type==1:
+            keypoint_img = draw_keypoints(boxed_image, result['points'], colors='red', width=3)
 
-        line_image = draw_lines_with_scores((img * 255).to(torch.uint8),  result['lines'],scores, width=3, cmap='jet')
+            self.writer.add_image("z-output", keypoint_img, epoch)
+        # print("lines shape:", result['lines'].shape)
 
-        self.writer.add_image("z-output_line", line_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+        if type==2:
+            # draw the line segments with my own function
+            # line_image = draw_lines(boxed_image, result['lines'], color='red', width=3)
+            print(f"shape of linescore:{result['liness_scores'].shape}")
+            scores = result['liness_scores'].mean(dim=1)  # shape: [31]
+
+            line_image = draw_lines_with_scores((img * 255).to(torch.uint8),  result['lines'],scores, width=3, cmap='jet')
+
+            self.writer.add_image("z-output_line", line_image.permute(1, 2, 0), epoch, dataformats="HWC")
 
 
 
@@ -311,7 +316,9 @@ class Trainer(BaseTrainer):
             if phase== 'val':
                 result,loss_dict = model(imgs, targets)
                 losses = sum(loss_dict.values()) if loss_dict else torch.tensor(0.0, device=device)
+
                 print(f'val losses:{losses}')
+                print(f'val result:{result}')
             else:
                 loss_dict = model(imgs, targets)
                 losses = sum(loss_dict.values()) if loss_dict else torch.tensor(0.0, device=device)