| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687 |
- import cv2
- import imageio
- import numpy as np
- from skimage.draw import ellipse
- from torch.utils.data.dataset import T_co
- from libs.vision_libs.utils import draw_keypoints
- from models.base.base_dataset import BaseDataset
- import json
- import os
- import PIL
- import matplotlib as mpl
- from torchvision.utils import draw_bounding_boxes
- import torchvision.transforms.v2 as transforms
- import torch
- import matplotlib.pyplot as plt
- from models.base.transforms import get_transforms
- from utils.data_process.mask.show_mask import save_full_mask
- from utils.data_process.show_prams import print_params
def validate_keypoints(keypoints, image_width, image_height):
    """Raise ``ValueError`` on the first keypoint lying outside the image.

    Args:
        keypoints: iterable of ``(x, y, v)`` triples; the third value is ignored.
        image_width: x must satisfy ``0 <= x < image_width``.
        image_height: y must satisfy ``0 <= y < image_height``.

    Raises:
        ValueError: naming the offending ``(x, y)`` and the image size.
    """
    for x, y, _unused in keypoints:
        inside = 0 <= x < image_width and 0 <= y < image_height
        if inside:
            continue
        raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
- """
- 直接读取xanlabel标注的数据集json格式
- """
class LineDataset(BaseDataset):
    """Dataset that reads xanlabel-annotated JSON label files directly.

    Layout expected under ``dataset_path`` (from the path joins below)::

        images/<dataset_type>/<name>.<ext>
        labels/<dataset_type>/<name>.json

    Each label file must contain a top-level ``"shapes"`` list; see
    ``get_boxes_lines`` for the object labels that are understood.
    """

    def __init__(self, dataset_path, data_type, transforms=None, augmentation=False, dataset_type=None, img_type='rgb',
                 target_type='pixel'):
        """
        Args:
            dataset_path: dataset root directory.
            data_type: ``'tiff'`` selects the imageio reader (channel 0 only);
                any other value opens the file with PIL and converts to RGB.
            transforms: callable ``(img, target) -> (img, target)``. When None,
                ``get_transforms(augmention=augmentation)`` is built on first use.
            augmentation: forwarded to ``get_transforms`` when ``transforms`` is None.
            dataset_type: split sub-directory (e.g. ``'train'``); None reads
                ``images/`` and ``labels/`` directly.
            img_type, target_type: stored on the instance; not read anywhere
                in this class.
        """
        super().__init__(dataset_path)
        self.data_path = dataset_path
        self.data_type = data_type
        print(f'data_path:{dataset_path}')
        self.transforms = transforms
        split = '' if dataset_type is None else dataset_type
        self.img_path = os.path.join(dataset_path, "images/" + split)
        self.lbl_path = os.path.join(dataset_path, "labels/" + split)
        # Sorted so that index -> file is reproducible; os.listdir order is arbitrary.
        self.imgs = sorted(os.listdir(self.img_path))
        self.lbls = sorted(os.listdir(self.lbl_path))
        self.target_type = target_type
        self.img_type = img_type
        self.augmentation = augmentation
        print(f'augmentation:{augmentation}')

    def __getitem__(self, index) -> T_co:
        """Load image ``index`` and its label file, then apply ``self.transforms``.

        Returns:
            ``(img, target)`` as produced by the transform pipeline.
        """
        name = self.imgs[index]
        img_path = os.path.join(self.img_path, name)
        # The label shares the image's stem: a.jpg / a.jpeg / a.tif / a.tiff -> a.json
        lbl_path = os.path.join(self.lbl_path, os.path.splitext(name)[0] + '.json')
        if self.data_type == 'tiff':
            # Keep channel 0 only and place it in channel 2 of an otherwise
            # zero 3-channel image.
            gray = imageio.v3.imread(img_path)[:, :, 0]
            h, w = gray.shape[:2]  # numpy arrays are (rows, cols) == (H, W)
            img_3channel = np.zeros((h, w, 3), dtype=gray.dtype)
            img_3channel[:, :, 2] = gray
            img = torch.from_numpy(img_3channel).permute(2, 0, 1)  # HWC -> CHW
        else:
            with PIL.Image.open(img_path) as opened:
                img = opened.convert('RGB')
            w, h = img.size  # PIL reports (width, height)
        target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w), image=img)
        if self.transforms is None:
            # Build the default pipeline once instead of on every sample; an
            # explicitly supplied ``transforms`` is respected.
            self.transforms = get_transforms(augmention=self.augmentation)
        img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        return len(self.imgs)

    def read_target(self, item, lbl_path, shape, extra=None, image=None):
        """Build the target dict for one image from its JSON label file.

        Args:
            item: sample index, stored as ``target["image_id"]``.
            lbl_path: path of the JSON file (top-level ``"shapes"`` list).
            shape: ``(H, W)``; used to clip boxes and as the arc-mask size.
            extra: unused.
            image: unused.

        Returns:
            dict with ``image_id``, ``boxes``, ``labels``, ``img_size`` and, when
            any object was recognised, ``lines`` ([N, 2, 3]: both endpoints with
            the object label appended), ``mask_ends``, ``mask_params`` and
            ``circle_masks`` ([N, H, W], all-zero rows for non-arc objects).
        """
        # JSON is UTF-8 by specification; labels may contain non-ASCII text.
        with open(lbl_path, 'r', encoding='utf-8') as file:
            label_all = json.load(file)
        objs = label_all["shapes"]
        target = {"image_id": torch.tensor(item)}
        boxes, lines, points, labels, arc_ends, arc_params = get_boxes_lines(objs, shape)
        if lines is not None:
            label_3d = labels.view(-1, 1, 1).expand(-1, 2, -1)  # [N] -> [N, 2, 1]
            target["lines"] = torch.cat([lines, label_3d], dim=-1).to(torch.float32)  # [N, 2, 3]
        if arc_ends is not None:
            target['mask_ends'] = arc_ends
        if arc_params is not None:
            target['mask_params'] = arc_params
            # Masks are rasterised at the image size so they align with the image.
            arc_masks = [
                arc_to_mask_safe(params, ends, shape=shape, debug=False)
                for params, ends in zip(arc_params, arc_ends)
            ]
            target['circle_masks'] = torch.stack(arc_masks, dim=0)
        target["boxes"] = boxes
        target["labels"] = labels
        target["img_size"] = shape
        return target

    def show(self, idx, show_type='all'):
        """Visualise sample ``idx`` with matplotlib.

        ``show_type`` selects the view:
          * ``'arc_yuan_point_ellipse'``: full ellipses from ``mask_params`` plus
            the arc endpoints from ``mask_ends``;
          * ``'circle_masks'``: every ``circle_masks`` entry concatenated side by side;
          * ``'circle_masks11'``: needs ``target['circles']``, which ``read_target``
            does not set;
          * ``'points'``: needs ``target['points']``, which ``read_target`` does not set;
          * ``'boxes'``: ``target['boxes']`` drawn on the image.
        Any other value (including the default ``'all'``) displays nothing.
        """
        img, target = self.__getitem__(idx)
        if show_type == 'arc_yuan_point_ellipse':
            arc_ends = target['mask_ends']
            arc_params = target['mask_params']
            fig, ax = plt.subplots()
            ax.imshow(img.permute(1, 2, 0))
            phi = np.linspace(0, 2 * np.pi, 500)
            for params in arc_params:
                if torch.all(params == 0):
                    continue  # zero placeholder row
                x, y, a, b, q = params.tolist()
                # NOTE(review): q is converted from degrees here, whereas
                # arc_to_mask_safe / compute_arc_angles use theta as radians — confirm the unit.
                theta = np.radians(q)
                x_ellipse = x + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
                y_ellipse = y + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
                ax.plot(x_ellipse, y_ellipse, 'b-', linewidth=2)
            labelled = False
            for point2 in arc_ends:
                if torch.all(point2 == 0):
                    continue
                ends_np = point2.cpu().numpy()
                # Label only the first series so the legend has a single entry.
                ax.plot(ends_np[:, 0], ends_np[:, 1], 'ro', markersize=6,
                        label='_nolegend_' if labelled else 'Arc Endpoints')
                labelled = True
            ax.legend()
            plt.axis('image')  # keep the aspect ratio equal
            plt.show()
        elif show_type == 'circle_masks':
            arc_mask = target['circle_masks']
            print(f'target circle_masks:{arc_mask.shape}')
            plt.imshow(torch.cat(list(arc_mask), dim=1))
            plt.show()
        elif show_type == 'circle_masks11':
            boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
                                              colors="yellow", width=1)
            circle = target['circles']
            circle_mask = target['circle_masks']
            print(f'taget circle:{circle.shape}')
            print(f'target circle_masks:{circle_mask.shape}')
            plt.imshow(circle_mask.squeeze(0))
            keypoint_img = draw_keypoints(boxed_image, circle, colors='red', width=3)
            plt.show()
        elif show_type == 'points':
            keypoint_img = draw_keypoints((img * 255).to(torch.uint8), target['points'].unsqueeze(1), colors='red',
                                          width=3)
            plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
            plt.show()
        elif show_type == 'boxes':
            boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
                                              colors="yellow", width=1)
            plt.imshow(boxed_image.permute(1, 2, 0).numpy())
            plt.show()

    def show_img(self, img_path):
        pass
- import torch
- import numpy as np
- import cv2
def arc_to_mask_safe(arc_param, arc_end, shape, line_width=5, debug=True, idx=-1):
    """Rasterise the minor (< 180 degree) arc of a rotated ellipse into a mask.

    The arc joins the two endpoints the short way round the ellipse. Degenerate
    input yields an all-zero mask instead of raising. Degenerate means any of:
    all-zero parameters or endpoints, non-positive semi-axes, NaN endpoint
    angles, or an exception while computing those angles.

    Args:
        arc_param: tensor of 5 values ``[cx, cy, a, b, theta]``. ``theta`` is fed
            straight to ``np.cos``/``np.sin``, i.e. interpreted as radians.
        arc_end: tensor of shape (2, 2), ``[[x1, y1], [x2, y2]]``.
        shape: ``(H, W)`` of the returned mask.
        line_width: stroke thickness passed to ``cv2.line``.
        debug: print diagnostics when True.
        idx: tag printed in brackets at the start of every diagnostic line.

    Returns:
        torch.float32 tensor of shape ``shape``: 1.0 on the arc, 0.0 elsewhere.
    """

    def give_up(reason, one_line=False):
        # Optionally report why the arc was rejected, then hand back an empty mask.
        if debug:
            print(f"[{idx}] {reason}")
            if one_line:
                print(f"[{idx}] arc_param: {arc_param.tolist()}, arc_end: {arc_end.tolist()}")
            else:
                print(f"[{idx}] arc_param: {arc_param.tolist()}")
                print(f"[{idx}] arc_end: {arc_end.tolist()}")
        return torch.zeros(shape, dtype=torch.float32)

    if torch.all(arc_param == 0) or torch.all(arc_end == 0):
        return give_up("Warning: arc_param or arc_end all zeros. Returning zero mask.")

    cx, cy, a, b, theta = arc_param.tolist()
    if a <= 0 or b <= 0:
        return give_up(f"Warning: invalid ellipse axes a={a}, b={b}. Returning zero mask.")

    x1, y1 = arc_end[0].tolist()
    x2, y2 = arc_end[1].tolist()
    c_t, s_t = np.cos(theta), np.sin(theta)

    def to_param_angle(px, py):
        # Move into the ellipse frame (translate, rotate by -theta), then scale
        # both axes onto the unit circle before taking the angle.
        du, dv = px - cx, py - cy
        u = c_t * du + s_t * dv
        v = -s_t * du + c_t * dv
        return np.arctan2(v / b, u / a)

    try:
        start, stop = to_param_angle(x1, y1), to_param_angle(x2, y2)
    except Exception as err:
        return give_up(f"Exception in point_to_angle: {err}", one_line=True)

    if np.isnan(start) or np.isnan(stop):
        return give_up("Warning: angle1 or angle2 is NaN. Returning zero mask.", one_line=True)

    # Walk in increasing parametric angle from start to stop; if that is the
    # long way round, walk from stop to start (+2*pi) so the span is <= pi.
    if stop < start:
        stop += 2 * np.pi
    if stop - start > np.pi:
        start, stop = stop, start + 2 * np.pi

    phi = np.linspace(start, stop, 100)
    cos_phi, sin_phi = np.cos(phi), np.sin(phi)
    xs = np.nan_to_num(cx + a * cos_phi * c_t - b * sin_phi * s_t, nan=0.0).astype(np.int32)
    ys = np.nan_to_num(cy + a * cos_phi * s_t + b * sin_phi * c_t, nan=0.0).astype(np.int32)

    if debug:
        print(f"[{idx}] arc_param: {arc_param.tolist()}")
        print(f"[{idx}] arc_end: {arc_end.tolist()}")
        print(f"[{idx}] xs[:5], ys[:5]: {xs[:5]}, {ys[:5]}")

    canvas = np.zeros(shape, dtype=np.uint8)
    vertices = np.stack([xs, ys], axis=1)
    # Stroke consecutive samples as short thick segments.
    for p0, p1 in zip(vertices[:-1], vertices[1:]):
        cv2.line(canvas, tuple(p0), tuple(p1), color=1, thickness=line_width)

    if debug:
        total = canvas.sum()
        if total == 0:
            print(f"[{idx}] Warning: mask generated is all zeros!")
        else:
            print(f"[{idx}] mask sum: {total}")

    return torch.tensor(canvas, dtype=torch.float32)
def draw_el(all):
    """Plot a rotated ellipse, its centre and two points on it.

    Args:
        all: 7 values ``(x, y, a, b, q, q1, q2)``, either a tensor (moved to CPU
            and converted to numpy) or any unpackable sequence. ``x, y`` is the
            centre and ``a, b`` the semi-axes. ``q`` is the rotation and
            ``q1``/``q2`` the parametric angles of the two marked points. All
            three angles are in degrees (converted with ``np.radians``).
    """
    values = all.cpu().numpy() if isinstance(all, torch.Tensor) else all
    xc, yc, a, b, rot_deg, deg1, deg2 = values
    theta = np.radians(rot_deg)

    def on_ellipse(phi):
        # Parametric point(s) of the rotated ellipse for parameter angle phi.
        px = xc + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
        py = yc + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
        return px, py

    curve_x, curve_y = on_ellipse(np.linspace(0, 2 * np.pi, 500))
    first = on_ellipse(np.radians(deg1))
    second = on_ellipse(np.radians(deg2))

    plt.figure(figsize=(10, 10))
    plt.plot(curve_x, curve_y, 'b-', linewidth=2)   # the ellipse outline
    plt.plot(xc, yc, 'ko', markersize=8)              # centre (black)
    plt.plot(first[0], first[1], 'ro', markersize=10)   # point at q1 (red)
    plt.plot(second[0], second[1], 'go', markersize=10)  # point at q2 (green)
    plt.show()
def arc_to_mask(arc7, shape, line_width=1):
    """Rasterise an elliptical arc described by 7 parameters into a mask.

    Args:
        arc7: tensor of 7 values ``[xc, yc, a, b, theta, phi1, phi2]``. These are
            the ellipse centre, the semi-axes, the rotation (radians, measured
            from the x axis towards the y axis) and the start/end parametric
            angles (radians). The arc runs from ``phi1`` forward to ``phi2``,
            wrapping past 2*pi when ``phi2 < phi1``.
        shape: ``(H, W)`` of the output mask.
        line_width: stroke thickness in pixels.

    Returns:
        torch.float32 tensor of shape (H, W) with values 0.0 / 255.0. An
        all-zero ``arc7`` yields an all-zero mask of the same dtype.
    """
    if torch.all(arc7 == 0):
        return torch.zeros(shape, dtype=torch.float32)
    # Work on Python floats. Unpacking the tensor directly gives views into
    # arc7, and the ``+=`` below would then modify the caller's tensor.
    xc, yc, a, b, theta, phi1, phi2 = arc7.tolist()
    H, W = shape
    if phi2 < phi1:
        phi2 += 2 * np.pi
    # Sample densely enough to avoid visible gaps (at least 10 points).
    num_points = max(int(200 * abs(phi2 - phi1) / (2 * np.pi)), 10)
    phi = np.linspace(phi1, phi2, num_points)
    # Parametric ellipse in its own (unrotated) frame.
    x_local = a * np.cos(phi)
    y_local = b * np.sin(phi)
    # Standard 2-D rotation by theta, then translation to the centre.
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    x_rot = x_local * cos_t - y_local * sin_t + xc
    y_rot = x_local * sin_t + y_local * cos_t + yc
    # OpenCV expects int32 vertex coordinates.
    points = np.stack([x_rot, y_rot], axis=1).astype(np.int32)
    img = np.zeros((H, W), dtype=np.uint8)
    # LINE_8 rather than LINE_AA keeps the mask strictly two-valued (0 / 255).
    cv2.polylines(img, [points], isClosed=False, color=255, thickness=line_width, lineType=cv2.LINE_8)
    return torch.from_numpy(img).float()
- def compute_arc_angles(gt_mask_ends, gt_mask_params):
- """
- 给定椭圆上的一个点,计算其对应的参数角 phi(弧度)。
- Parameters:
- point: tuple or array-like, (x, y)
- ellipse_param: tuple or array-like, (xc, yc, a, b, theta)
- Returns:
- phi: float, in [0, 2*pi)
- """
- # print_params(gt_mask_ends, gt_mask_params)
- results = []
- if not isinstance(gt_mask_params, torch.Tensor):
- gt_mask_params_tensor = torch.tensor(gt_mask_params, dtype=gt_mask_ends.dtype, device=gt_mask_ends.device)
- else:
- gt_mask_params_tensor = gt_mask_params.clone().detach().to(gt_mask_ends)
- for ends_img, params_img in zip(gt_mask_ends, gt_mask_params_tensor):
- # print(f'params_img:{params_img}')
- if torch.norm(params_img) < 1e-6: # L2 norm near zero
- results.append(torch.zeros(2, device=params_img.device, dtype=params_img.dtype))
- continue
- x, y = ends_img
- xc, yc, a, b, theta = params_img
- # 1. 平移到中心
- dx = x - xc
- dy = y - yc
- # 2. 逆旋转(旋转 -theta)
- cos_t = torch.cos(theta)
- sin_t = torch.sin(theta)
- X = dx * cos_t + dy * sin_t
- Y = -dx * sin_t + dy * cos_t
- # 3. 归一化到单位圆(除以 a, b)
- cos_phi = X / a
- sin_phi = Y / b
- # 4. 用 atan2 求角度(自动处理象限)
- phi = torch.atan2(sin_phi, cos_phi)
- # 5. 转换到 [0, 2π)
- phi = torch.where(phi < 0, phi + 2 * torch.pi, phi)
- results.append(phi)
- return results
def points_to_ellipse(points):
    """Roughly estimate ellipse parameters from a few points on the ellipse.

    This is a heuristic, not a conic fit:
      * centre      = mean of the points;
      * r1 / r2     = largest / smallest distance from that centre;
      * orientation = direction (radians) from the first point to the second.

    Args:
        points: tensor or array-like reshapeable to (K, 2) with K >= 2,
            typically the four points of shape (4, 2). Tensors are detached
            and moved to CPU first.

    Returns:
        ``(cx, cy, r1, r2, orientation)``. ``cx, cy`` are the centre
        coordinates, ``r1, r2`` the estimated major/minor semi-axis radii, and
        ``orientation`` an angle in radians.
    """
    if isinstance(points, torch.Tensor):
        points = points.detach().cpu().numpy()
    pts = np.asarray(points).reshape(-1, 2)
    center = pts.mean(axis=0)
    distances = np.linalg.norm(pts - center, axis=1)
    orientation = np.arctan2(pts[1, 1] - pts[0, 1], pts[1, 0] - pts[0, 0])
    return center[0], center[1], distances.max(), distances.min(), orientation
def generate_ellipse_mask(shape, ellipse_params):
    """Rasterise a filled ellipse into a uint8 mask of the given shape.

    Args:
        shape: output mask shape ``(H, W)``.
        ellipse_params: ``(cx, cy, rx, ry, orientation)``. The centre and radii
            are truncated to int before drawing.

    Returns:
        numpy uint8 array of ``shape``: 1 inside the ellipse, 0 elsewhere.

    NOTE(review): ``orientation`` is unpacked but not passed to
    ``skimage.draw.ellipse``, so the drawn ellipse is always axis-aligned.
    """
    cx, cy, rx, ry, _orientation = ellipse_params
    mask = np.zeros(shape, dtype=np.uint8)
    # skimage works in (row, col) order, so y/ry come before x/rx.
    rows, cols = ellipse(int(cy), int(cx), int(ry), int(rx), shape)
    mask[rows, cols] = 1
    return mask
def sort_points_clockwise(points):
    """Order 2-D points by their angle around the top-most point.

    The reference is the point with the smallest y (ties broken by the
    smallest x). Each point's angle from the reference is ``arctan2(dy, dx)``
    mapped into ``[0, 2*pi)``. Points are returned in ascending angle order,
    which is clockwise on screen assuming image coordinates (y grows downwards).

    Args:
        points: sequence of ``[x, y]`` pairs.

    Returns:
        list of ``[x, y]`` lists in sorted order.
    """
    pts = np.array(points)
    ref = pts[np.lexsort((pts[:, 0], pts[:, 1]))[0]]
    theta = np.array([np.arctan2(p[1] - ref[1], p[0] - ref[0]) for p in pts])
    theta[theta < 0] += 2 * np.pi
    return pts[np.argsort(theta)].tolist()
def get_boxes_lines(objs, shape):
    """Convert annotation shapes into per-object, index-aligned tensors.

    Every recognised object contributes one row to every output. Fields that
    do not apply to an object are zero-filled placeholders, so row i always
    refers to the same object.

    Recognised labels:
      * ``'line'`` / ``'dseam1'``: label 2; box = segment bounds padded by 6 px;
      * ``'point'``: label 1; box = the point padded by 12 px;
      * ``'arc'``: label 4; box = bounds of ``obj['points']`` padded by 40 px,
        plus ``obj['ends']`` and ``obj['params']``.
    Objects with any other label are skipped.

    Args:
        objs: list of shape dicts with ``'label'`` and ``'points'`` keys; arcs
            also need ``'ends'`` and ``'params'``.
        shape: ``(h, w)``; boxes are clipped to ``[0, w] x [0, h]``.

    Returns:
        ``(boxes, line_point_pairs, points, labels, mask_ends, mask_params)``:
          * ``boxes``: float32 [N, 4] as ``xmin, ymin, xmax, ymax``; [0, 4] when empty;
          * ``labels``: int64 [N];
          * ``line_point_pairs``: float32 [N, 2, 2];
          * ``points``: float32 [N, 2];
          * ``mask_ends``: float32 tensor; presumably [N, 2, 2], and arc
            ``'ends'`` must match the placeholder shape;
          * ``mask_params``: float32 tensor; [N, 5] when arc ``'params'`` have
            5 values like the placeholders.
        The four optional tensors are None when no object was recognised.
    """
    h, w = shape
    boxes, labels, line_point_pairs, points, mask_ends, mask_params = [], [], [], [], [], []
    zero_pair = [[0, 0], [0, 0]]
    zero_params = [0, 0, 0, 0, 0]
    zero_point = [0.0, 0.0]

    def padded_box(xs, ys, pad):
        # Bounds of the coordinates, padded by `pad` and clipped to the image.
        return [max(0, min(xs) - pad), max(0, min(ys) - pad),
                min(w, max(xs) + pad), min(h, max(ys) + pad)]

    for obj in objs:
        label = obj['label']
        if label == 'line' or label == 'dseam1':
            a, b = obj['points'][0], obj['points'][1]
            boxes.append(padded_box([a[0], b[0]], [a[1], b[1]], 6))
            labels.append(2)
            line_point_pairs.append([a, b])
            points.append(zero_point)
            mask_ends.append(zero_pair)
            mask_params.append(zero_params)
        elif label == 'point':
            p = obj['points'][0]
            boxes.append(padded_box([p[0]], [p[1]], 12))
            labels.append(1)
            line_point_pairs.append(zero_pair)
            points.append([p[0], p[1]])
            mask_ends.append(zero_pair)
            mask_params.append(zero_params)
        elif label == 'arc':
            arc_points = obj['points']
            boxes.append(padded_box([q[0] for q in arc_points], [q[1] for q in arc_points], 40))
            labels.append(4)
            line_point_pairs.append(zero_pair)
            points.append(zero_point)
            mask_ends.append(obj['ends'])
            mask_params.append(obj['params'])

    def float_tensor_or_none(rows):
        # All lists have one row per recognised object, so they are empty together.
        return torch.tensor(rows, dtype=torch.float32) if rows else None

    return (
        torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
        float_tensor_or_none(line_point_pairs),
        float_tensor_or_none(points),
        torch.tensor(labels, dtype=torch.int64),
        float_tensor_or_none(mask_ends),
        float_tensor_or_none(mask_params),
    )
if __name__ == '__main__':
    # Manual smoke test: load the 'train' split and plot the arcs of sample 19.
    dataset_root = r'/data/share/zyh/master_dataset/pokou/251115/a_dataset_pokou_mask'
    dataset = LineDataset(
        dataset_path=dataset_root,
        dataset_type='train',
        augmentation=False,
        data_type='jpg',
    )
    dataset.show(19, show_type='arc_yuan_point_ellipse')
|