|
|
@@ -15,11 +15,11 @@ from models.base.base_model import BaseModel
|
|
|
from models.base.base_trainer import BaseTrainer
|
|
|
from models.config.config_tool import read_yaml
|
|
|
from models.line_detect.line_dataset import LineDataset
|
|
|
+import torch.nn.functional as F
|
|
|
+
|
|
|
|
|
|
-from models.line_net.dataset_LD import WirePointDataset
|
|
|
-from models.wirenet.postprocess import postprocess
|
|
|
from tools import utils
|
|
|
-from torchvision import transforms
|
|
|
+
|
|
|
import matplotlib as mpl
|
|
|
|
|
|
cmap = plt.get_cmap("jet")
|
|
|
@@ -44,12 +44,71 @@ def c(x):
|
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
|
-import matplotlib.pyplot as plt
|
|
|
-from PIL import ImageDraw
|
|
|
-from torchvision.transforms import functional as F
|
|
|
-import torch
|
|
|
|
|
|
|
|
|
+def draw_ellipses_on_image(image, masks_pred, threshold=0.5, color=(0, 255, 0), thickness=2):
|
|
|
+ """
|
|
|
+ å¨åå¼ åå§å¾åä¸ç»å¶ä» masks æååºçæ¤åã
|
|
|
+ èªå¨å° masks resize å° image ç空é´å°ºå¯¸ã
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image: Tensor [3, H_img, W_img] ââ åå§å¾åï¼å¦ [3, 2000, 2000]ï¼
|
|
|
+ masks_pred: Tensor [N, 1, H_mask, W_mask] or [N, H_mask, W_mask] ââ æ¨¡åè¾åº maskï¼å¦ [2, 1, 672, 672]ï¼
|
|
|
+ threshold: äºå¼åéå¼
|
|
|
+ color: BGR color for OpenCV
|
|
|
+ thickness: ellipse line thickness
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ drawn_image: numpy array [H_img, W_img, 3] in RGB
|
|
|
+ """
|
|
|
+ # Step 1: æ åå masks_pred to [N, H, W]
|
|
|
+ if masks_pred.ndim == 4:
|
|
|
+ if masks_pred.shape[1] == 1:
|
|
|
+ masks_pred = masks_pred.squeeze(1) # [N, 1, H, W] -> [N, H, W]
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Expected channel=1 in masks_pred, got shape {masks_pred.shape}")
|
|
|
+ elif masks_pred.ndim != 3:
|
|
|
+ raise ValueError(f"masks_pred must be 3D (N, H, W) or 4D (N, 1, H, W), got {masks_pred.shape}")
|
|
|
+
|
|
|
+ N, H_mask, W_mask = masks_pred.shape
|
|
|
+ C, H_img, W_img = image.shape
|
|
|
+
|
|
|
+ # Step 2: Resize masks to original image size using bilinear interpolation
|
|
|
+ masks_resized = F.interpolate(
|
|
|
+ masks_pred.unsqueeze(1).float(), # [N, 1, H_mask, W_mask]
|
|
|
+ size=(H_img, W_img),
|
|
|
+ mode='bilinear',
|
|
|
+ align_corners=False
|
|
|
+ ).squeeze(1) # [N, H_img, W_img]
|
|
|
+
|
|
|
+ # Step 3: Convert image to numpy RGB
|
|
|
+ img_tensor = image.detach().cpu()
|
|
|
+ if img_tensor.max() <= 1.0:
|
|
|
+ img_np = (img_tensor * 255).byte().numpy() # [3, H, W]
|
|
|
+ else:
|
|
|
+ img_np = img_tensor.byte().numpy()
|
|
|
+ img_rgb = np.transpose(img_np, (1, 2, 0)) # [H, W, 3]
|
|
|
+ img_out = img_rgb.copy()
|
|
|
+
|
|
|
+ # Step 4: Process each mask
|
|
|
+ for mask in masks_resized:
|
|
|
+ mask_cpu = mask.detach().cpu()
|
|
|
+ mask_prob = torch.sigmoid(mask_cpu) if mask_cpu.min() < 0 else mask_cpu
|
|
|
+ binary = (mask_prob > threshold).numpy().astype(np.uint8) * 255 # [H_img, W_img]
|
|
|
+
|
|
|
+ contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
|
+ if contours:
|
|
|
+ largest_contour = max(contours, key=cv2.contourArea)
|
|
|
+ if len(largest_contour) >= 5:
|
|
|
+ try:
|
|
|
+ ellipse = cv2.fitEllipse(largest_contour)
|
|
|
+ img_bgr = cv2.cvtColor(img_out, cv2.COLOR_RGB2BGR)
|
|
|
+ cv2.ellipse(img_bgr, ellipse, color=color, thickness=thickness)
|
|
|
+ img_out = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
|
|
+ except cv2.error as e:
|
|
|
+ print(f"Warning: Failed to fit ellipse: {e}")
|
|
|
+
|
|
|
+ return img_out
|
|
|
|
|
|
|
|
|
def fit_circle(points):
|
|
|
@@ -282,25 +341,32 @@ class Trainer(BaseTrainer):
|
|
|
bbb=result['boxes']
|
|
|
print(f'boxes shape:{bbb.shape}')
|
|
|
print(f'ppp:{ppp.shape}')
|
|
|
- points = result['ins_masks']
|
|
|
- points = points.squeeze(1)
|
|
|
- print(f'points shape:{points.shape}')
|
|
|
+ ins_masks = result['ins_masks']
|
|
|
+ ins_masks = ins_masks.squeeze(1)
|
|
|
+ print(f'ins_masks shape:{ins_masks.shape}')
|
|
|
features = result['features']
|
|
|
|
|
|
circle_image = img.cpu().numpy().transpose((1, 2, 0)) # CHW -> HWC
|
|
|
circle_image = (circle_image * 255).clip(0, 255).astype(np.uint8)
|
|
|
|
|
|
- sum_mask = points.sum(dim=0, keepdim=True)
|
|
|
+ sum_mask = ins_masks.sum(dim=0, keepdim=True)
|
|
|
sum_mask = sum_mask / (sum_mask.max() + 1e-8)
|
|
|
|
|
|
# keypoint_img = draw_keypoints((img * 255).to(torch.uint8), points, colors='red', width=3)
|
|
|
self.writer.add_image('z-ins-masks', sum_mask.squeeze(0), global_step=epoch)
|
|
|
+
|
|
|
+
|
|
|
+ result_imgs = draw_ellipses_on_image(img, ins_masks, threshold=0.5)
|
|
|
+ self.writer.add_image('z-out-ellipses', result_imgs, dataformats='HWC', global_step= epoch)
|
|
|
+
|
|
|
features=self.apply_gaussian_blur_to_tensor(features,sigma=3)
|
|
|
self.writer.add_image('z-feature', features, global_step=epoch)
|
|
|
|
|
|
# cv2.imshow('arc', img_rgb)
|
|
|
# cv2.waitKey(1000000)
|
|
|
|
|
|
+
|
|
|
+
|
|
|
def normalize_tensor(self,tensor):
|
|
|
"""Normalize tensor to [0, 1]"""
|
|
|
min_val = tensor.min()
|