admin před 1 měsícem
rodič
revize
759cfd3e3b

+ 3 - 1
models/line_detect/loi_heads.py

@@ -1448,9 +1448,11 @@ class RoIHeads(nn.Module):
                         if feature_logits is not None:
 
                             circles_probs, circles_scores = circle_inference(feature_logits, circle_proposals)
-                            for keypoint_prob, kps, r in zip(circles_probs, circles_scores, result):
+                            for keypoint_prob, kps, r,f in zip(circles_probs, circles_scores, result,feature_logits):
                                 r["circles"] = keypoint_prob
                                 r["circles_scores"] = kps
+                                print(f'circles feature map:{f.shape}')
+                                r["features"] = f
 
                 print(f'loss_circle:{loss_circle}')
                 print(f'loss_circle_extra:{loss_circle_extra}')

+ 2 - 0
models/line_detect/trainer.py

@@ -284,6 +284,7 @@ class Trainer(BaseTrainer):
             points = result['circles']
             points = points.squeeze(1)
             print(f'points shape:{points.shape}')
+            features = result['features']
 
             circle_image = img.cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC
             circle_image = (circle_image * 255).clip(0, 255).astype(np.uint8)
@@ -309,6 +310,7 @@ class Trainer(BaseTrainer):
 
             # keypoint_img = draw_keypoints((img * 255).to(torch.uint8), points, colors='red', width=3)
             self.writer.add_image('z-out-circle', img_tensor, global_step=epoch)
+            self.writer.add_image('z-feature', features, global_step=epoch)
 
             # cv2.imshow('arc', img_rgb)
             # cv2.waitKey(1000000)