import cv2
import imageio
import numpy as np
from skimage.draw import ellipse
from torch.utils.data.dataset import T_co

from libs.vision_libs.utils import draw_keypoints
from models.base.base_dataset import BaseDataset

import json
import os

import PIL.Image  # binds `PIL` and guarantees the `PIL.Image` submodule is loaded
import matplotlib as mpl
from torchvision.utils import draw_bounding_boxes
import torchvision.transforms.v2 as transforms
import torch
import matplotlib.pyplot as plt

from models.base.transforms import get_transforms
from utils.data_process.mask.show_mask import save_full_mask
from utils.data_process.show_prams import print_params


def validate_keypoints(keypoints, image_width, image_height):
    """Raise ``ValueError`` if any ``(x, y, v)`` keypoint lies outside the image.

    Args:
        keypoints: iterable of ``(x, y, v)`` triples; ``v`` is unpacked but not checked.
        image_width: valid x range is ``[0, image_width)``.
        image_height: valid y range is ``[0, image_height)``.
    """
    for x, y, _v in keypoints:
        if not (0 <= x < image_width and 0 <= y < image_height):
            raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")


class LineDataset(BaseDataset):
    """Dataset that reads X-AnyLabeling JSON annotations directly.

    Directory layout used by ``__init__``::

        <dataset_path>/images/<dataset_type>/<name>.<ext>
        <dataset_path>/labels/<dataset_type>/<name>.json

    Each JSON file must contain a ``"shapes"`` list; see ``get_boxes_lines`` for the
    labels that are understood.
    """

    def __init__(self, dataset_path, data_type, transforms=None, augmentation=False, dataset_type=None,
                 img_type='rgb', target_type='pixel'):
        """
        Args:
            dataset_path: dataset root directory.
            data_type: ``'tiff'`` selects the imageio/TIFF loader; anything else uses PIL.
            transforms: optional callable ``(img, target) -> (img, target)``. When ``None``,
                ``get_transforms(augmention=augmentation)`` is used.
            augmentation: forwarded to ``get_transforms`` when ``transforms`` is ``None``.
            dataset_type: split sub-directory name (e.g. ``'train'``).
            img_type: stored on the instance; not used in this class.
            target_type: stored on the instance; not used in this class.
        """
        super().__init__(dataset_path)

        self.data_path = dataset_path
        self.data_type = data_type
        print(f'data_path:{dataset_path}')
        self.img_path = os.path.join(dataset_path, "images", dataset_type)
        self.lbl_path = os.path.join(dataset_path, "labels", dataset_type)
        # Sorted so an index always maps to the same file; os.listdir order is arbitrary.
        self.imgs = sorted(os.listdir(self.img_path))
        self.lbls = sorted(os.listdir(self.lbl_path))
        self.target_type = target_type
        self.img_type = img_type
        self.augmentation = augmentation
        print(f'augmentation:{augmentation}')
        # Honour a caller-supplied pipeline; otherwise build the default one once here
        # rather than rebuilding it for every sample.
        if transforms is None:
            transforms = get_transforms(augmention=augmentation)
        self.transforms = transforms

    def __getitem__(self, index) -> T_co:
        """Load image ``index`` and its annotation, then apply ``self.transforms``."""
        img_name = self.imgs[index]
        img_path = os.path.join(self.img_path, img_name)
        # The label file shares the image's stem; splitext handles any extension length.
        lbl_path = os.path.join(self.lbl_path, os.path.splitext(img_name)[0] + '.json')

        if self.data_type == 'tiff':
            # Keep only the first channel and place it in channel 2 of an otherwise
            # zero 3-channel image.
            gray = imageio.v3.imread(img_path)[:, :, 0]
            h, w = gray.shape[:2]  # numpy arrays are (rows, cols) == (height, width)
            img_3channel = np.zeros((h, w, 3), dtype=gray.dtype)
            img_3channel[:, :, 2] = gray
            # HWC -> CHW so tensor rows/cols line up with the (x, y) annotations.
            img = torch.from_numpy(img_3channel).permute(2, 0, 1)
        else:
            with PIL.Image.open(img_path) as pil_img:
                img = pil_img.convert('RGB')
            w, h = img.size  # PIL reports (width, height)

        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w), image=img)
        img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        return len(self.imgs)

    def read_target(self, item, lbl_path, shape, extra=None, image=None):
        """Build the target dict for one image from its JSON annotation.

        Args:
            item: sample index, stored as ``target["image_id"]``.
            lbl_path: path of the JSON annotation file.
            shape: ``(h, w)`` of the image; used to clip boxes and as the arc-mask size.
            extra: unused.
            image: unused (kept for debugging / visualisation hooks).

        Returns:
            dict with ``image_id``, ``boxes``, ``labels``, ``img_size`` and, when the file
            has recognised objects, ``lines`` (N, 2, 3) = both endpoints as (x, y, label),
            ``mask_ends``, ``mask_params`` and ``circle_masks`` (N, h, w).
        """
        with open(lbl_path, 'r', encoding='utf-8') as file:
            label_all = json.load(file)

        objs = label_all["shapes"]

        target = {"image_id": torch.tensor(item)}

        boxes, lines, points, labels, arc_ends, arc_params = get_boxes_lines(objs, shape)

        if lines is not None:
            # Append the class label to both endpoints: [N] -> [N, 2, 1] -> cat -> [N, 2, 3]
            label_3d = labels.view(-1, 1, 1).expand(-1, 2, -1).to(lines.dtype)
            target["lines"] = torch.cat([lines, label_3d], dim=-1).to(torch.float32)

        if arc_ends is not None:
            target['mask_ends'] = arc_ends

        if arc_params is not None:
            target['mask_params'] = arc_params
            # One mask per object; non-arc objects carry zero params and get an empty mask.
            arc_masks = [
                arc_to_mask_safe(params, ends, shape=shape, debug=False)
                for params, ends in zip(arc_params, arc_ends)
            ]
            target['circle_masks'] = torch.stack(arc_masks, dim=0)

        target["boxes"] = boxes
        target["labels"] = labels
        target["img_size"] = shape

        return target

    def show(self, idx, show_type='all'):
        """Visualise sample ``idx`` with matplotlib.

        ``show_type`` selects the view: ``'arc_yuan_point_ellipse'``, ``'circle_masks'``,
        ``'circle_masks11'``, ``'points'`` or ``'boxes'``; any other value draws nothing.
        """
        image, target = self.__getitem__(idx)
        img = image

        if show_type == 'arc_yuan_point_ellipse':
            arc_ends = target['mask_ends']
            arc_params = target['mask_params']

            fig, ax = plt.subplots()
            ax.imshow(img.permute(1, 2, 0))

            for params in arc_params:
                if torch.all(params == 0):
                    continue
                x, y, a, b, q = params
                # NOTE(review): the rotation is treated as degrees here, while
                # arc_to_mask_safe / compute_arc_angles pass it to cos/sin as radians.
                # Confirm which unit the annotations use.
                theta = np.radians(q)
                phi = np.linspace(0, 2 * np.pi, 500)
                x_ellipse = x + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
                y_ellipse = y + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
                plt.plot(x_ellipse, y_ellipse, 'b-', linewidth=2)

            for point2 in arc_ends:
                if torch.all(point2 == 0):
                    continue
                ends_np = point2.cpu().numpy()
                ax.plot(ends_np[:, 0], ends_np[:, 1], 'ro', markersize=6, label='Arc Endpoints')

            ax.legend()
            plt.axis('image')  # keep the x/y aspect ratio equal
            plt.show()

        if show_type == 'circle_masks':
            arc_mask = target['circle_masks']
            print(f'target circle_masks:{arc_mask.shape}')
            # Concatenate all per-object masks side by side along the width.
            combined = torch.cat(list(arc_mask), dim=1)
            plt.imshow(combined)
            plt.show()

        if show_type == 'circle_masks11':
            # NOTE(review): read_target never stores 'circles'; this branch raises KeyError
            # unless the transforms add that key.
            boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"], colors="yellow", width=1)
            circle = target['circles']
            circle_mask = target['circle_masks']
            print(f'taget circle:{circle.shape}')
            print(f'target circle_masks:{circle_mask.shape}')
            plt.imshow(circle_mask.squeeze(0))
            keypoint_img = draw_keypoints(boxed_image, circle, colors='red', width=3)
            plt.show()

        if show_type == 'points':
            # NOTE(review): read_target does not store 'points'; this branch raises KeyError
            # unless the transforms add that key.
            keypoint_img = draw_keypoints((img * 255).to(torch.uint8), target['points'].unsqueeze(1), colors='red', width=3)
            plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
            plt.show()

        if show_type == 'boxes':
            boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"], colors="yellow", width=1)
            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
            plt.show()

    def show_img(self, img_path):
        """Placeholder; not implemented."""
        pass


def arc_to_mask_safe(arc_param, arc_end, shape, line_width=5, debug=True, idx=-1):
    """
    Generate a mask for a small (<180 degree) elliptical arc from its parameters and endpoints.

    Degenerate input (all-zero params or ends, non-positive axes, NaN angles) yields an
    all-zero mask instead of raising.

    Args:
        arc_param: torch.Tensor of shape (5,) - [cx, cy, a, b, theta]; theta is passed
            directly to np.cos / np.sin, i.e. interpreted as radians.
        arc_end: torch.Tensor of shape (2, 2) - [[x1, y1], [x2, y2]]
        shape: tuple (H, W) - mask size
        line_width: thickness of the drawn arc in pixels
        debug: if True, print diagnostic info
        idx: int or str - tag used in the debug output

    Returns:
        torch.Tensor of shape (H, W), dtype float32, values 0 or 1.
    """
    # ------------------ Reject all-zero (placeholder) input ------------------
    if torch.all(arc_param == 0) or torch.all(arc_end == 0):
        if debug:
            print(f"[{idx}] Warning: arc_param or arc_end all zeros. Returning zero mask.")
            print(f"[{idx}] arc_param: {arc_param.tolist()}")
            print(f"[{idx}] arc_end: {arc_end.tolist()}")
        return torch.zeros(shape, dtype=torch.float32)

    cx, cy, a, b, theta = arc_param.tolist()
    if a <= 0 or b <= 0:
        if debug:
            print(f"[{idx}] Warning: invalid ellipse axes a={a}, b={b}. Returning zero mask.")
            print(f"[{idx}] arc_param: {arc_param.tolist()}")
            print(f"[{idx}] arc_end: {arc_end.tolist()}")
        return torch.zeros(shape, dtype=torch.float32)

    x1, y1 = arc_end[0].tolist()
    x2, y2 = arc_end[1].tolist()

    cos_t = np.cos(theta)
    sin_t = np.sin(theta)

    def point_to_angle(x, y):
        # Undo translation and rotation, then scale to the unit circle to read off phi.
        dx = x - cx
        dy = y - cy
        x_ = cos_t * dx + sin_t * dy
        y_ = -sin_t * dx + cos_t * dy
        return np.arctan2(y_ / b, x_ / a)

    try:
        angle1 = point_to_angle(x1, y1)
        angle2 = point_to_angle(x2, y2)
    except Exception as e:  # best effort: a bad arc yields an empty mask
        if debug:
            print(f"[{idx}] Exception in point_to_angle: {e}")
            print(f"[{idx}] arc_param: {arc_param.tolist()}, arc_end: {arc_end.tolist()}")
        return torch.zeros(shape, dtype=torch.float32)

    if np.isnan(angle1) or np.isnan(angle2):
        if debug:
            print(f"[{idx}] Warning: angle1 or angle2 is NaN. Returning zero mask.")
            print(f"[{idx}] arc_param: {arc_param.tolist()}, arc_end: {arc_end.tolist()}")
        return torch.zeros(shape, dtype=torch.float32)

    # Walk from angle1 to angle2 counter-clockwise; if that sweep exceeds pi, walk the
    # complementary (shorter) way instead so the drawn arc is always < 180 degrees.
    if angle2 < angle1:
        angle2 += 2 * np.pi
    if angle2 - angle1 > np.pi:
        angle1, angle2 = angle2, angle1 + 2 * np.pi

    angles = np.linspace(angle1, angle2, 100)
    xs = cx + a * np.cos(angles) * cos_t - b * np.sin(angles) * sin_t
    ys = cy + a * np.cos(angles) * sin_t + b * np.sin(angles) * cos_t

    xs = np.nan_to_num(xs, nan=0.0).astype(np.int32)
    ys = np.nan_to_num(ys, nan=0.0).astype(np.int32)

    # ------------------ Debug prints ------------------
    if debug:
        print(f"[{idx}] arc_param: {arc_param.tolist()}")
        print(f"[{idx}] arc_end: {arc_end.tolist()}")
        print(f"[{idx}] xs[:5], ys[:5]: {xs[:5]}, {ys[:5]}")

    mask = np.zeros(shape, dtype=np.uint8)
    # Plain Python ints avoid OpenCV point-parsing issues with numpy scalar types.
    pts = np.stack([xs, ys], axis=1).tolist()

    # Draw the arc as consecutive thick line segments.
    for start, end in zip(pts, pts[1:]):
        cv2.line(mask, tuple(start), tuple(end), color=1, thickness=line_width)

    # ------------------ Extra check for an empty mask ------------------
    if debug:
        mask_sum = mask.sum()
        if mask_sum == 0:
            print(f"[{idx}] Warning: mask generated is all zeros!")
        else:
            print(f"[{idx}] mask sum: {mask_sum}")

    return torch.tensor(mask, dtype=torch.float32)


def draw_el(all):
    """Plot an ellipse and two points on it given by parametric angles.

    Args:
        all: 7 values ``(x, y, a, b, q, q1, q2)`` (tensor or array-like): centre, semi-axes,
            rotation ``q`` and the parametric angles ``q1``/``q2`` of the two points, all
            three angles in degrees (converted with ``np.radians``).
    """
    # Unpack ellipse parameters.
    if isinstance(all, torch.Tensor):
        all = all.cpu().numpy()
    x, y, a, b, q, q1, q2 = all
    theta = np.radians(q)
    phi1 = np.radians(q1)  # parametric angle of the first point
    phi2 = np.radians(q2)  # parametric angle of the second point

    # Sample points along the whole ellipse.
    phi = np.linspace(0, 2 * np.pi, 500)
    x_ellipse = x + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
    y_ellipse = y + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)

    # Coordinates of the two requested points.
    def param_to_point(phi, xc, yc, a, b, theta):
        x = xc + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
        y = yc + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
        return x, y

    P1 = param_to_point(phi1, x, y, a, b, theta)
    P2 = param_to_point(phi2, x, y, a, b, theta)

    plt.figure(figsize=(10, 10))

    # Ellipse (blue), centre (black), first point (red), second point (green).
    plt.plot(x_ellipse, y_ellipse, 'b-', linewidth=2)
    plt.plot(x, y, 'ko', markersize=8)
    plt.plot(P1[0], P1[1], 'ro', markersize=10)
    plt.plot(P2[0], P2[1], 'go', markersize=10)

    plt.show()


def arc_to_mask(arc7, shape, line_width=1):
    """
    Rasterise an elliptical arc into a binary mask.

    Args:
        arc7: 7 values ``(xc, yc, a, b, theta, phi1, phi2)`` as a tensor or sequence:
            ellipse centre, semi-axes, rotation ``theta`` (radians, counter-clockwise from
            the x axis) and the start/end parametric angles ``phi1``/``phi2`` (radians).
            The arc is drawn from phi1 forward to phi2, wrapping past 2*pi if needed.
        shape: (H, W) of the output mask.
        line_width: arc thickness in pixels.

    Returns:
        torch.Tensor [H, W], float32, values 0 or 255 (all zeros when every value of
        ``arc7`` is 0).
    """
    if isinstance(arc7, torch.Tensor):
        arc7 = arc7.detach().cpu().tolist()
    H, W = shape
    if all(v == 0 for v in arc7):
        return torch.zeros((H, W), dtype=torch.float32)

    xc, yc, a, b, theta, phi1, phi2 = arc7

    # Make phi1 -> phi2 a forward sweep (handles crossing 2*pi).
    if phi2 < phi1:
        phi2 += 2 * np.pi

    # Dense enough sampling to avoid gaps in the polyline.
    num_points = max(int(200 * abs(phi2 - phi1) / (2 * np.pi)), 10)
    phi = np.linspace(phi1, phi2, num_points)

    # Parametric ellipse in its own (unrotated) frame.
    x_local = a * np.cos(phi)
    y_local = b * np.sin(phi)

    # Standard 2-D rotation by theta, then translation to the centre.
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    x_rot = x_local * cos_t - y_local * sin_t + xc
    y_rot = x_local * sin_t + y_local * cos_t + yc

    # OpenCV needs int32 vertex coordinates.
    points = np.stack([x_rot, y_rot], axis=1).astype(np.int32)

    img = np.zeros((H, W), dtype=np.uint8)
    # LINE_8 (no antialiasing) keeps the mask strictly 0/255.
    cv2.polylines(img, [points], isClosed=False, color=255, thickness=line_width, lineType=cv2.LINE_8)

    return torch.from_numpy(img).float()


def compute_arc_angles(gt_mask_ends, gt_mask_params):
    """
    Compute the parametric angle phi of each arc endpoint on its ellipse.

    Args:
        gt_mask_ends: tensor with one entry per arc, each either a single point ``(x, y)``
            or a pair ``[[x1, y1], [x2, y2]]`` (the (N, 2, 2) layout built by
            ``get_boxes_lines``).
        gt_mask_params: (N, 5) ellipse parameters ``(xc, yc, a, b, theta)`` with theta in
            radians; a tensor or anything ``torch.tensor`` accepts.

    Returns:
        list with one tensor per arc holding the endpoint angle(s) in [0, 2*pi); arcs whose
        parameters are (near) all-zero yield ``torch.zeros(2)``.
    """
    results = []

    if not isinstance(gt_mask_params, torch.Tensor):
        gt_mask_params_tensor = torch.tensor(gt_mask_params, dtype=gt_mask_ends.dtype, device=gt_mask_ends.device)
    else:
        gt_mask_params_tensor = gt_mask_params.clone().detach().to(gt_mask_ends)

    for ends_img, params_img in zip(gt_mask_ends, gt_mask_params_tensor):
        if torch.norm(params_img) < 1e-6:  # placeholder (non-arc) entry
            results.append(torch.zeros(2, device=params_img.device, dtype=params_img.dtype))
            continue

        # Split the last axis into x and y; works for one (2,) point or a (2, 2) pair.
        x, y = ends_img.unbind(-1)
        xc, yc, a, b, theta = params_img

        # 1. Translate to the ellipse centre.
        dx = x - xc
        dy = y - yc

        # 2. Rotate by -theta into the ellipse frame.
        cos_t = torch.cos(theta)
        sin_t = torch.sin(theta)
        X = dx * cos_t + dy * sin_t
        Y = -dx * sin_t + dy * cos_t

        # 3. Normalise to the unit circle.
        cos_phi = X / a
        sin_phi = Y / b

        # 4. atan2 resolves the quadrant.
        phi = torch.atan2(sin_phi, cos_phi)

        # 5. Map to [0, 2*pi).
        phi = torch.where(phi < 0, phi + 2 * torch.pi, phi)

        results.append(phi)

    return results


def points_to_ellipse(points):
    """
    Roughly estimate ellipse parameters from points assumed to lie on it.

    This is a heuristic, not a conic fit: the centre is the mean of the points, the two
    radii are the largest and smallest distance from that centre, and the orientation is
    the angle (radians) of the segment from point 0 to point 1.

    :param points: tensor or array-like reshapable to (-1, 2) with at least two points
        (typically four).
    :return: (cx, cy, r_long, r_short, orientation)
    """
    if isinstance(points, torch.Tensor):
        pts = points.detach().cpu().numpy()
    else:
        pts = np.asarray(points)
    pts = pts.reshape(-1, 2)

    center = np.mean(pts, axis=0)
    distances = np.linalg.norm(pts - center, axis=1)
    long_radius = np.max(distances)
    short_radius = np.min(distances)
    orientation = np.arctan2(pts[1, 1] - pts[0, 1], pts[1, 0] - pts[0, 0])

    return center[0], center[1], long_radius, short_radius, orientation


def generate_ellipse_mask(shape, ellipse_params):
    """
    Draw a filled, axis-aligned ellipse mask.

    :param shape: output mask shape (H, W)
    :param ellipse_params: (cx, cy, rx, ry, orientation); values are truncated to int and
        ``orientation`` is currently ignored (no rotation is applied).
    :return: uint8 numpy array, 1 inside the ellipse, 0 elsewhere
    """
    cx, cy, rx, ry, orientation = ellipse_params
    img = np.zeros(shape, dtype=np.uint8)
    cx, cy, rx, ry = int(cx), int(cy), int(rx), int(ry)
    rr, cc = ellipse(cy, cx, ry, rx, shape)
    img[rr, cc] = 1
    return img


def sort_points_clockwise(points):
    """Sort 2-D points by angle around the top-most (then left-most) point.

    Angles come from ``atan2(dy, dx)`` mapped to [0, 2*pi) and are sorted ascending, which
    is clockwise on screen when y grows downward (image coordinates).

    Returns:
        list of ``[x, y]`` lists.
    """
    points = np.array(points)
    # lexsort uses the last key (y) as primary and x as tie-breaker.
    top_left_idx = np.lexsort((points[:, 0], points[:, 1]))[0]
    reference_point = points[top_left_idx]

    def angle_to_reference(point):
        return np.arctan2(point[1] - reference_point[1], point[0] - reference_point[0])

    angles = np.apply_along_axis(angle_to_reference, 1, points)
    angles[angles < 0] += 2 * np.pi
    sorted_indices = np.argsort(angles)
    sorted_points = points[sorted_indices]

    return sorted_points.tolist()


def _padded_box(xs, ys, pad, w, h):
    """Return [xmin, ymin, xmax, ymax] around the coordinates, padded by ``pad`` and clipped to the image."""
    return [max(min(xs) - pad, 0), max(min(ys) - pad, 0), min(max(xs) + pad, w), min(max(ys) + pad, h)]


def get_boxes_lines(objs, shape):
    """Convert X-AnyLabeling ``shapes`` into per-object tensors.

    Recognised labels (anything else is skipped):
      * ``'line'`` / ``'dseam1'`` -> class 2; box = segment bbox padded by 6 px.
      * ``'point'``               -> class 1; box = point +/- 12 px.
      * ``'arc'``                 -> class 4; box = bbox of its ``points`` padded by 40 px,
                                     plus the shape's ``'params'`` and ``'ends'``.
    Every returned tensor has one row per recognised object; fields that do not apply to
    an object are zero-filled so all tensors stay aligned.

    Args:
        objs: list of shape dicts with ``'label'`` and ``'points'`` (arcs also ``'params'``
            and ``'ends'``).
        shape: ``(h, w)`` used to clip boxes.

    Returns:
        boxes: (N, 4) float32 xyxy.
        line_point_pairs: (N, 2, 2) float32, or None when N == 0.
        points: (N, 2) float32, or None when N == 0.
        labels: (N,) int64.
        mask_ends: float32, or None when N == 0; (N, 2, 2) assuming each arc's ``'ends'``
            is two (x, y) points.
        mask_params: float32, or None when N == 0; (N, 5) assuming each arc's
            ``'params'`` has 5 values.
    """
    boxes = []
    labels = []
    h, w = shape
    line_point_pairs = []
    points = []
    mask_ends = []
    mask_params = []

    for obj in objs:
        label = obj['label']

        if label == 'line' or label == 'dseam1':
            a, b = obj['points'][0], obj['points'][1]
            line_point_pairs.append([a, b])
            boxes.append(_padded_box([a[0], b[0]], [a[1], b[1]], 6, w, h))
            labels.append(2)
            points.append([0.0, 0.0])
            mask_ends.append([[0, 0], [0, 0]])
            mask_params.append([0, 0, 0, 0, 0])

        elif label == 'point':
            p = obj['points'][0]
            points.append(p)
            labels.append(1)
            boxes.append(_padded_box([p[0]], [p[1]], 12, w, h))
            line_point_pairs.append([[0, 0], [0, 0]])
            mask_ends.append([[0, 0], [0, 0]])
            mask_params.append([0, 0, 0, 0, 0])

        elif label == 'arc':
            mask_ends.append(obj['ends'])
            mask_params.append(obj['params'])

            arc3points = obj['points']
            xs = [p[0] for p in arc3points]
            ys = [p[1] for p in arc3points]
            boxes.append(_padded_box(xs, ys, 40, w, h))
            labels.append(4)
            points.append([0.0, 0.0])
            line_point_pairs.append([[0, 0], [0, 0]])

    # reshape keeps an empty result at (0, 4) instead of (0,).
    boxes = torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4)
    labels = torch.tensor(labels, dtype=torch.int64)

    points = torch.tensor(points, dtype=torch.float32) if points else None
    line_point_pairs = torch.tensor(line_point_pairs, dtype=torch.float32) if line_point_pairs else None
    mask_ends = torch.tensor(mask_ends, dtype=torch.float32) if mask_ends else None
    mask_params = torch.tensor(mask_params, dtype=torch.float32) if mask_params else None

    return boxes, line_point_pairs, points, labels, mask_ends, mask_params


if __name__ == '__main__':
    path = r'/data/share/zyh/master_dataset/pokou/251115/a_dataset_pokou_mask'
    dataset = LineDataset(dataset_path=path, dataset_type='train', augmentation=False, data_type='jpg')
    dataset.show(19, show_type='arc_yuan_point_ellipse')