Change circle feature map to Gaussian heatmap

admin · 1 month ago · commit cb0bac6c57

models/line_detect/line_dataset.py (+2 -2)

@@ -303,6 +303,6 @@ def get_boxes_lines(objs,shape):
     return boxes,line_point_pairs,points,line_mask,circle_4points, labels
 
 if __name__ == '__main__':
-    path=r'/data/share/zyh/master_dataset/circle/huayan_circle/a_dataset'
+    path=r'/data/share/zyh/master_dataset/circle/huyan_eclipse/a_dataset'
     dataset= LineDataset(dataset_path=path, dataset_type='train',augmentation=False, data_type='jpg')
-    dataset.show(9,show_type='all')
+    dataset.show(33,show_type='all')

models/line_detect/train.yaml (+1 -1)

@@ -7,7 +7,7 @@ io:
 #  datadir: /data/share/rlq/datasets/250718caisegangban
 
 #  datadir: /data/share/rlq/datasets/singepoint_Dataset0709_2
-  datadir: /data/share/zyh/master_dataset/circle/huayan_circle/a_dataset
+  datadir: /data/share/zyh/master_dataset/circle/huyan_eclipse/a_dataset
 #  datadir: \\192.168.50.222/share/rlq/datasets/singepoint_Dataset0709_2
 #  datadir: \\192.168.50.222/share/rlq/datasets/250718caisegangban
   data_type: rgb
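
Only the `io.datadir` entry changes in this file. As a minimal sketch (assuming the config is read with PyYAML; the loading code below is illustrative and not the project's actual config reader), the new path would be picked up like this:

```python
# Illustrative only: read io.datadir / io.data_type from train.yaml with PyYAML.
# The project's real config loader is not shown in this diff.
import yaml

with open("models/line_detect/train.yaml") as f:
    cfg = yaml.safe_load(f)

datadir = cfg["io"]["datadir"]      # -> /data/share/zyh/master_dataset/circle/huyan_eclipse/a_dataset
data_type = cfg["io"]["data_type"]  # -> "rgb"
```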

models/line_detect/train_demo.py (+3 -2)

@@ -22,8 +22,9 @@ if __name__ == '__main__':
     # model = linedetect_newresnet101fpn(num_points=4)
     # model = linedetect_newresnet152fpn(num_points=4)
     # model.load_weights(save_path=r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250711_114046/weights/best_val.pth')
-    model=linedetect_maxvitfpn()
-    # model=linedetect_efficientnet(name='efficientnet_v2_l')
+    # model=linedetect_maxvitfpn()
+
+    model=linedetect_efficientnet(name='efficientnet_v2_l')
     # model=linedetect_high_maxvitfpn()
 
     # model=linedetect_swin_transformer_fpn(type='t')
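
This hunk switches training from the MaxViT-FPN backbone to EfficientNetV2-L. The `linedetect_efficientnet` constructor itself is not part of the diff; the sketch below only illustrates, under assumptions, how such a backbone could be wired to an FPN with torchvision. The chosen stage indices and `out_channels=256` are guesses, not the project's actual implementation:

```python
# Illustrative only: one way to wrap an EfficientNetV2-L trunk with an FPN via torchvision.
# This is NOT the project's linedetect_efficientnet(); stage choice and out_channels are assumptions.
import torch
import torchvision
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection.backbone_utils import BackboneWithFPN

body = torchvision.models.efficientnet_v2_l(weights=None).features  # nn.Sequential of stages "0".."8"
return_layers = {"3": "0", "5": "1", "7": "2"}  # stages assumed to feed the FPN

# Probe the chosen stages once instead of hard-coding their channel counts.
with torch.no_grad():
    probe = IntermediateLayerGetter(body, return_layers)(torch.zeros(1, 3, 224, 224))
in_channels_list = [f.shape[1] for f in probe.values()]

backbone = BackboneWithFPN(body, return_layers, in_channels_list, out_channels=256)
out = backbone(torch.zeros(1, 3, 224, 224))  # OrderedDict of FPN maps keyed "0", "1", "2", "pool"
```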

models/line_detect/trainer.py (+34 -0)

@@ -6,6 +6,7 @@ import cv2
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
+from scipy.ndimage import gaussian_filter
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.utils.tensorboard import SummaryWriter
 
@@ -310,13 +311,46 @@ class Trainer(BaseTrainer):
 
             # keypoint_img = draw_keypoints((img * 255).to(torch.uint8), points, colors='red', width=3)
             self.writer.add_image('z-out-circle', img_tensor, global_step=epoch)
+            features=self.apply_gaussian_blur_to_tensor(features,sigma=3)
             self.writer.add_image('z-feature', features, global_step=epoch)
 
             # cv2.imshow('arc', img_rgb)
             # cv2.waitKey(1000000)
 
+    def normalize_tensor(self,tensor):
+        """Normalize tensor to [0, 1]"""
+        min_val = tensor.min()
+        max_val = tensor.max()
+        return (tensor - min_val) / (max_val - min_val + 1e-8)  # epsilon avoids division by zero on a flat map
 
+    def apply_gaussian_blur_to_tensor(self,feature_map, sigma=3):
+        """
+        Apply Gaussian blur to a feature map and convert it into an RGB heatmap.
 
+        :param feature_map: Tensor of shape (H, W) or (1, H, W)
+        :param sigma: Standard deviation for Gaussian kernel
+        :return: Tensor of shape (3, H, W) representing the RGB heatmap
+        """
+        if feature_map.dim() == 3:
+            if feature_map.shape[0] != 1:
+                raise ValueError("Only single-channel feature map supported.")
+            feature_map = feature_map.squeeze(0)
+
+        # Normalize to [0, 1]
+        normalized_feat = self.normalize_tensor(feature_map).cpu().numpy()
+
+        # Apply Gaussian blur
+        blurred_feat = gaussian_filter(normalized_feat, sigma=sigma)
+
+        # Convert to colormap (e.g., 'jet')
+        colormap = plt.get_cmap('jet')
+        colored = colormap(blurred_feat)  # shape: (H, W, 4) RGBA
+
+        # Convert to (3, H, W), drop alpha channel
+        colored_rgb = colored[:, :, :3]  # (H, W, 3)
+        colored_tensor = torch.from_numpy(colored_rgb).permute(2, 0, 1)  # (3, H, W)
+
+        return colored_tensor.float()
     def writer_loss(self, losses, epoch, phase='train'):
         try:
             for key, value in losses.items():
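
For quick verification outside the trainer, the same pipeline that the new `apply_gaussian_blur_to_tensor` method implements can be run standalone; the random feature map and the `runs/heatmap_demo` log directory below are placeholders, not part of the commit:

```python
# Standalone sketch of the Gaussian-heatmap conversion added to the trainer above.
# The feature map here is random data; in the trainer it comes from the model.
import torch
from matplotlib import pyplot as plt
from scipy.ndimage import gaussian_filter
from torch.utils.tensorboard import SummaryWriter

feature_map = torch.rand(1, 64, 64)          # (1, H, W) single-channel feature map
fm = feature_map.squeeze(0)

# Normalize to [0, 1], blur, then map through the 'jet' colormap.
fm = (fm - fm.min()) / (fm.max() - fm.min() + 1e-8)
blurred = gaussian_filter(fm.cpu().numpy(), sigma=3)
rgba = plt.get_cmap('jet')(blurred)          # (H, W, 4) RGBA in [0, 1]
heatmap = torch.from_numpy(rgba[:, :, :3]).permute(2, 0, 1).float()  # (3, H, W)

writer = SummaryWriter(log_dir='runs/heatmap_demo')   # placeholder log dir
writer.add_image('z-feature', heatmap, global_step=0)
writer.close()
```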