
Add ellipse drawing from masks to the TensorBoard output

admin 1 month ago
parent
commit
6ad8d910e6
1 changed file with 77 additions and 11 deletions
  1. models/line_detect/trainer.py  +77 -11

models/line_detect/trainer.py  +77 -11

@@ -15,11 +15,11 @@ from models.base.base_model import BaseModel
 from models.base.base_trainer import BaseTrainer
 from models.config.config_tool import read_yaml
 from models.line_detect.line_dataset import LineDataset
+import torch.nn.functional as F
+
 
-from models.line_net.dataset_LD import WirePointDataset
-from models.wirenet.postprocess import postprocess
 from tools import utils
-from torchvision import transforms
+
 import matplotlib as mpl
 
 cmap = plt.get_cmap("jet")
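The new torch.nn.functional import exists to resize mask tensors to the image resolution. A minimal sketch of that call pattern, using the shapes mentioned in the docstring below; the concrete sizes are illustrative:

import torch
import torch.nn.functional as F

masks = torch.rand(2, 672, 672)        # [N, H_mask, W_mask], stand-in for model output masks
resized = F.interpolate(
    masks.unsqueeze(1),                # [N, 1, H_mask, W_mask]
    size=(2000, 2000),                 # target (H_img, W_img)
    mode='bilinear',
    align_corners=False,
).squeeze(1)                           # back to [N, H_img, W_img]
print(resized.shape)                   # torch.Size([2, 2000, 2000])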
@@ -44,12 +44,71 @@ def c(x):
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-import matplotlib.pyplot as plt
-from PIL import ImageDraw
-from torchvision.transforms import functional as F
-import torch
 
 
+def draw_ellipses_on_image(image, masks_pred, threshold=0.5, color=(0, 255, 0), thickness=2):
+    """
+    在单张原始图像上绘制从 masks 拟合出的椭圆。
+    自动将 masks resize 到 image 的空间尺寸。
+
+    Args:
+        image: Tensor [3, H_img, W_img] —— 原始图像(如 [3, 2000, 2000])
+        masks_pred: Tensor [N, 1, H_mask, W_mask] or [N, H_mask, W_mask] —— 模型输出 mask(如 [2, 1, 672, 672])
+        threshold: 二值化阈值
+        color: BGR color for OpenCV
+        thickness: ellipse line thickness
+
+    Returns:
+        drawn_image: numpy array [H_img, W_img, 3] in RGB
+    """
+    # Step 1: 标准化 masks_pred to [N, H, W]
+    if masks_pred.ndim == 4:
+        if masks_pred.shape[1] == 1:
+            masks_pred = masks_pred.squeeze(1)  # [N, 1, H, W] -> [N, H, W]
+        else:
+            raise ValueError(f"Expected channel=1 in masks_pred, got shape {masks_pred.shape}")
+    elif masks_pred.ndim != 3:
+        raise ValueError(f"masks_pred must be 3D (N, H, W) or 4D (N, 1, H, W), got {masks_pred.shape}")
+
+    N, H_mask, W_mask = masks_pred.shape
+    C, H_img, W_img = image.shape
+
+    # Step 2: Resize masks to original image size using bilinear interpolation
+    masks_resized = F.interpolate(
+        masks_pred.unsqueeze(1).float(),  # [N, 1, H_mask, W_mask]
+        size=(H_img, W_img),
+        mode='bilinear',
+        align_corners=False
+    ).squeeze(1)  # [N, H_img, W_img]
+
+    # Step 3: Convert image to numpy RGB
+    img_tensor = image.detach().cpu()
+    if img_tensor.max() <= 1.0:
+        img_np = (img_tensor * 255).byte().numpy()  # [3, H, W]
+    else:
+        img_np = img_tensor.byte().numpy()
+    img_rgb = np.transpose(img_np, (1, 2, 0))  # [H, W, 3]
+    img_out = img_rgb.copy()
+
+    # Step 4: Process each mask
+    for mask in masks_resized:
+        mask_cpu = mask.detach().cpu()
+        mask_prob = torch.sigmoid(mask_cpu) if mask_cpu.min() < 0 else mask_cpu
+        binary = (mask_prob > threshold).numpy().astype(np.uint8) * 255  # [H_img, W_img]
+
+        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+        if contours:
+            largest_contour = max(contours, key=cv2.contourArea)
+            if len(largest_contour) >= 5:
+                try:
+                    ellipse = cv2.fitEllipse(largest_contour)
+                    img_bgr = cv2.cvtColor(img_out, cv2.COLOR_RGB2BGR)
+                    cv2.ellipse(img_bgr, ellipse, color=color, thickness=thickness)
+                    img_out = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+                except cv2.error as e:
+                    print(f"Warning: Failed to fit ellipse: {e}")
+
+    return img_out
 
 
 def fit_circle(points):
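A minimal usage sketch for draw_ellipses_on_image outside the trainer, assuming a synthetic image and mask logits; the log directory, tag, and blob geometry are illustrative, while the call signature and shapes follow the function above:

import torch
from torch.utils.tensorboard import SummaryWriter
from models.line_detect.trainer import draw_ellipses_on_image

# Build one elliptical blob of positive logits so fitEllipse has something to fit.
ys = torch.arange(128).view(-1, 1).float()
xs = torch.arange(128).view(1, -1).float()
blob = ((xs - 64) / 40) ** 2 + ((ys - 64) / 25) ** 2 < 1.0
mask_logits = (blob.float() * 12.0 - 6.0).unsqueeze(0).unsqueeze(0)   # [1, 1, 128, 128]

image = torch.rand(3, 512, 512)                                       # stand-in input image in [0, 1]
overlay = draw_ellipses_on_image(image, mask_logits, threshold=0.5)   # HWC uint8 RGB with the ellipse drawn

writer = SummaryWriter('runs/ellipse_demo')                           # illustrative log directory
writer.add_image('z-out-ellipses', overlay, dataformats='HWC', global_step=0)
writer.close()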
@@ -282,25 +341,32 @@ class Trainer(BaseTrainer):
             bbb=result['boxes']
             print(f'boxes shape:{bbb.shape}')
             print(f'ppp:{ppp.shape}')
-            points = result['ins_masks']
-            points = points.squeeze(1)
-            print(f'points shape:{points.shape}')
+            ins_masks = result['ins_masks']
+            ins_masks = ins_masks.squeeze(1)
+            print(f'ins_masks shape:{ins_masks.shape}')
             features = result['features']
 
             circle_image = img.cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC
             circle_image = (circle_image * 255).clip(0, 255).astype(np.uint8)
 
-            sum_mask = points.sum(dim=0, keepdim=True)
+            sum_mask = ins_masks.sum(dim=0, keepdim=True)
             sum_mask = sum_mask / (sum_mask.max() + 1e-8)
 
             # keypoint_img = draw_keypoints((img * 255).to(torch.uint8), points, colors='red', width=3)
             self.writer.add_image('z-ins-masks', sum_mask.squeeze(0), global_step=epoch)
+
+
+            result_imgs = draw_ellipses_on_image(img, ins_masks, threshold=0.5)
+            self.writer.add_image('z-out-ellipses', result_imgs, dataformats='HWC', global_step=epoch)
+
             features=self.apply_gaussian_blur_to_tensor(features,sigma=3)
             self.writer.add_image('z-feature', features, global_step=epoch)
 
             # cv2.imshow('arc', img_rgb)
             # cv2.waitKey(1000000)
 
+
+
     def normalize_tensor(self,tensor):
         """Normalize tensor to [0, 1]"""
         min_val = tensor.min()
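For reference, the contour-to-ellipse step used in the validation hook relies on standard OpenCV calls; a self-contained sketch on a synthetic binary mask (sizes illustrative) showing the (center, axes, angle) tuple that cv2.fitEllipse returns and cv2.ellipse consumes:

import cv2
import numpy as np

# Synthetic filled ellipse as a stand-in for a thresholded instance mask.
mask = np.zeros((256, 256), dtype=np.uint8)
cv2.ellipse(mask, ((128, 128), (120, 60), 30.0), 255, -1)

contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
largest = max(contours, key=cv2.contourArea)
if len(largest) >= 5:                                  # fitEllipse needs at least 5 contour points
    (cx, cy), (major, minor), angle = cv2.fitEllipse(largest)
    print(f'center=({cx:.1f}, {cy:.1f}), axes=({major:.1f}, {minor:.1f}), angle={angle:.1f}')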