|
|
@@ -71,7 +71,7 @@ class LineDataset(BaseDataset):
|
|
|
img = PIL.Image.open(img_path).convert('RGB')
|
|
|
w, h = img.size
|
|
|
# wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
|
- target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
|
+ target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w),image=img)
|
|
|
|
|
|
self.transforms = get_transforms(augmention=self.augmentation)
|
|
|
|
|
|
@@ -82,7 +82,7 @@ class LineDataset(BaseDataset):
|
|
|
def __len__(self):
|
|
|
return len(self.imgs)
|
|
|
|
|
|
- def read_target(self, item, lbl_path, shape, extra=None):
|
|
|
+ def read_target(self, item, lbl_path, shape, extra=None,image=None):
|
|
|
# print(f'shape:{shape}')
|
|
|
# print(f'lbl_path:{lbl_path}')
|
|
|
with open(lbl_path, 'r') as file:
|
|
|
@@ -118,30 +118,21 @@ class LineDataset(BaseDataset):
|
|
|
target['mask_params'] = arc_params
|
|
|
|
|
|
|
|
|
- arc_angles = compute_arc_angles(arc_ends, arc_params)
|
|
|
- # print_params(arc_angles)
|
|
|
- arc_masks = []
|
|
|
|
|
|
+ # arc_angles = compute_arc_angles(arc_ends, arc_params)
|
|
|
|
|
|
|
|
|
+ print_params(arc_ends,arc_params)
|
|
|
+ arc_masks = []
|
|
|
for i in range(len(arc_params)):
|
|
|
- arc_param_i = arc_params[i].view(-1) # shape (5,)
|
|
|
- arc_angle_i = arc_angles[i].view(-1) # shape (2,)
|
|
|
- arc7 = torch.cat([arc_param_i, arc_angle_i], dim=0) # shape (7,)
|
|
|
-
|
|
|
-
|
|
|
- # print_params(arc7)
|
|
|
- mask = arc_to_mask(arc7, shape, line_width=1)
|
|
|
-
|
|
|
+ mask = arc_to_mask_safe(arc_params[i], arc_ends[i], shape=(2000, 2000))
|
|
|
arc_masks.append(mask)
|
|
|
- # arc7=arc_params[i] + arc_angles[i].tolist()
|
|
|
- # arc_masks.append(arc_to_mask(arc7, shape, line_width=1))
|
|
|
-
|
|
|
- # print(f'circle_masks:{torch.stack(arc_masks, dim=0).shape}')
|
|
|
-
|
|
|
+ print_params(arc_masks)
|
|
|
target['circle_masks'] = torch.stack(arc_masks, dim=0)
|
|
|
- save_full_mask(target['circle_masks'], "arc_masks",
|
|
|
- "/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_dataset")
|
|
|
+
|
|
|
+ # save_full_mask(torch.stack(arc_masks, dim=0), "arc_masks",
|
|
|
+ # "/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_dataset",
|
|
|
+ # force_save=False,image=image,show_on_image=True)
|
|
|
|
|
|
|
|
|
|
|
|
@@ -249,6 +240,109 @@ class LineDataset(BaseDataset):
|
|
|
pass
|
|
|
|
|
|
|
|
|
+import torch
|
|
|
+import numpy as np
|
|
|
+import cv2
|
|
|
+
|
|
|
+def arc_to_mask_safe(arc_param, arc_end, shape, line_width=5, debug=True, idx=-1):
|
|
|
+ """
|
|
|
+ Generate a mask for a small (<180 degree) arc based on arc parameters and endpoints.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ arc_param: torch.Tensor of shape (5,) - [cx, cy, a, b, theta]
|
|
|
+ arc_end: torch.Tensor of shape (2,2) - [[x1,y1],[x2,y2]]
|
|
|
+ shape: tuple (H,W) - mask size
|
|
|
+ line_width: thickness of the arc
|
|
|
+ debug: bool - if True, print debug info
|
|
|
+ idx: int or str - index for debugging identification
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ mask: torch.Tensor of shape (H,W)
|
|
|
+ """
|
|
|
+
|
|
|
+ # ------------------ Check for all-zero input ------------------
|
|
|
+ if torch.all(arc_param == 0) or torch.all(arc_end == 0):
|
|
|
+ if debug:
|
|
|
+ print(f"[{idx}] Warning: arc_param or arc_end all zeros. Returning zero mask.")
|
|
|
+ print(f"[{idx}] arc_param: {arc_param.tolist()}")
|
|
|
+ print(f"[{idx}] arc_end: {arc_end.tolist()}")
|
|
|
+ return torch.zeros(shape, dtype=torch.float32)
|
|
|
+
|
|
|
+ cx, cy, a, b, theta = arc_param.tolist()
|
|
|
+
|
|
|
+ if a <= 0 or b <= 0:
|
|
|
+ if debug:
|
|
|
+ print(f"[{idx}] Warning: invalid ellipse axes a={a}, b={b}. Returning zero mask.")
|
|
|
+ print(f"[{idx}] arc_param: {arc_param.tolist()}")
|
|
|
+ print(f"[{idx}] arc_end: {arc_end.tolist()}")
|
|
|
+ return torch.zeros(shape, dtype=torch.float32)
|
|
|
+
|
|
|
+ x1, y1 = arc_end[0].tolist()
|
|
|
+ x2, y2 = arc_end[1].tolist()
|
|
|
+
|
|
|
+ cos_t = np.cos(theta)
|
|
|
+ sin_t = np.sin(theta)
|
|
|
+
|
|
|
+ def point_to_angle(x, y):
|
|
|
+ dx = x - cx
|
|
|
+ dy = y - cy
|
|
|
+ x_ = cos_t * dx + sin_t * dy
|
|
|
+ y_ = -sin_t * dx + cos_t * dy
|
|
|
+ return np.arctan2(y_ / b, x_ / a)
|
|
|
+
|
|
|
+ try:
|
|
|
+ angle1 = point_to_angle(x1, y1)
|
|
|
+ angle2 = point_to_angle(x2, y2)
|
|
|
+ except Exception as e:
|
|
|
+ if debug:
|
|
|
+ print(f"[{idx}] Exception in point_to_angle: {e}")
|
|
|
+ print(f"[{idx}] arc_param: {arc_param.tolist()}, arc_end: {arc_end.tolist()}")
|
|
|
+ return torch.zeros(shape, dtype=torch.float32)
|
|
|
+
|
|
|
+ if np.isnan(angle1) or np.isnan(angle2):
|
|
|
+ if debug:
|
|
|
+ print(f"[{idx}] Warning: angle1 or angle2 is NaN. Returning zero mask.")
|
|
|
+ print(f"[{idx}] arc_param: {arc_param.tolist()}, arc_end: {arc_end.tolist()}")
|
|
|
+ return torch.zeros(shape, dtype=torch.float32)
|
|
|
+
|
|
|
+ # Ensure small arc (<180 degrees)
|
|
|
+ if angle2 < angle1:
|
|
|
+ angle2 += 2 * np.pi
|
|
|
+ if angle2 - angle1 > np.pi:
|
|
|
+ angle1, angle2 = angle2, angle1 + 2 * np.pi
|
|
|
+
|
|
|
+ angles = np.linspace(angle1, angle2, 100)
|
|
|
+ xs = cx + a * np.cos(angles) * cos_t - b * np.sin(angles) * sin_t
|
|
|
+ ys = cy + a * np.cos(angles) * sin_t + b * np.sin(angles) * cos_t
|
|
|
+
|
|
|
+ xs = np.nan_to_num(xs, nan=0.0).astype(np.int32)
|
|
|
+ ys = np.nan_to_num(ys, nan=0.0).astype(np.int32)
|
|
|
+
|
|
|
+ # ------------------ Debug prints ------------------
|
|
|
+ if debug:
|
|
|
+ print(f"[{idx}] arc_param: {arc_param.tolist()}")
|
|
|
+ print(f"[{idx}] arc_end: {arc_end.tolist()}")
|
|
|
+ print(f"[{idx}] xs[:5], ys[:5]: {xs[:5]}, {ys[:5]}")
|
|
|
+
|
|
|
+ mask = np.zeros(shape, dtype=np.uint8)
|
|
|
+ pts = np.stack([xs, ys], axis=1)
|
|
|
+
|
|
|
+ # Draw the arc with given line_width
|
|
|
+ for i in range(len(pts) - 1):
|
|
|
+ cv2.line(mask, tuple(pts[i]), tuple(pts[i + 1]), color=1, thickness=line_width)
|
|
|
+
|
|
|
+ # ------------------ Extra check for non-zero mask ------------------
|
|
|
+ if debug:
|
|
|
+ mask_sum = mask.sum()
|
|
|
+ if mask_sum == 0:
|
|
|
+ print(f"[{idx}] Warning: mask generated is all zeros!")
|
|
|
+ else:
|
|
|
+ print(f"[{idx}] mask sum: {mask_sum}")
|
|
|
+
|
|
|
+ return torch.tensor(mask, dtype=torch.float32)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
def draw_el(all):
|
|
|
# 解析椭圆参数
|
|
|
if isinstance(all, torch.Tensor):
|
|
|
@@ -585,6 +679,6 @@ def get_boxes_lines(objs, shape):
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
- path = r'\\192.168.50.222/share/zyh/master_dataset/pokou/251115/a_dataset_pokou_mask'
|
|
|
+ path = r'/data/share/zyh/master_dataset/pokou/251115/a_dataset_pokou_mask'
|
|
|
dataset = LineDataset(dataset_path=path, dataset_type='train', augmentation=False, data_type='jpg')
|
|
|
dataset.show(19, show_type='arc_yuan_point_ellipse')
|