|
|
@@ -71,7 +71,7 @@ class LineDataset(BaseDataset):
|
|
|
img = PIL.Image.open(img_path).convert('RGB')
|
|
|
w, h = img.size
|
|
|
# wire_labels, target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
|
- target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w),image=img)
|
|
|
+ target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))
|
|
|
|
|
|
self.transforms = get_transforms(augmention=self.augmentation)
|
|
|
|
|
|
@@ -82,7 +82,7 @@ class LineDataset(BaseDataset):
|
|
|
def __len__(self):
|
|
|
return len(self.imgs)
|
|
|
|
|
|
- def read_target(self, item, lbl_path, shape,extra=None,image=None):
|
|
|
+ def read_target(self, item, lbl_path, shape, extra=None):
|
|
|
# print(f'shape:{shape}')
|
|
|
# print(f'lbl_path:{lbl_path}')
|
|
|
with open(lbl_path, 'r') as file:
|
|
|
@@ -122,30 +122,26 @@ class LineDataset(BaseDataset):
|
|
|
# print_params(arc_angles)
|
|
|
arc_masks = []
|
|
|
|
|
|
+
|
|
|
+
|
|
|
for i in range(len(arc_params)):
|
|
|
arc_param_i = arc_params[i].view(-1) # shape (5,)
|
|
|
arc_angle_i = arc_angles[i].view(-1) # shape (2,)
|
|
|
arc7 = torch.cat([arc_param_i, arc_angle_i], dim=0) # shape (7,)
|
|
|
|
|
|
|
|
|
+ # print_params(arc7)
|
|
|
mask = arc_to_mask(arc7, shape, line_width=1)
|
|
|
|
|
|
arc_masks.append(mask)
|
|
|
+ # arc7=arc_params[i] + arc_angles[i].tolist()
|
|
|
+ # arc_masks.append(arc_to_mask(arc7, shape, line_width=1))
|
|
|
|
|
|
- print_params(arc_masks,image)
|
|
|
+ # print(f'circle_masks:{torch.stack(arc_masks, dim=0).shape}')
|
|
|
|
|
|
target['circle_masks'] = torch.stack(arc_masks, dim=0)
|
|
|
-
|
|
|
- # for i, m in enumerate(target['circle_masks']):
|
|
|
- # save_full_mask(
|
|
|
- # m,
|
|
|
- # name=f"arc_mask_{i}",
|
|
|
- # out_dir=r"/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_dataset",
|
|
|
- # save_png=True,
|
|
|
- # save_npy=True,
|
|
|
- # image=image,
|
|
|
- # show_on_image=True
|
|
|
- # )
|
|
|
+ save_full_mask(target['circle_masks'], "arc_masks",
|
|
|
+ "/home/zhaoyinghan/py_ws/code/circle_huayan/MultiVisionModels/models/line_detect/out_feature_dataset")
|
|
|
|
|
|
|
|
|
|
|
|
@@ -177,9 +173,36 @@ class LineDataset(BaseDataset):
|
|
|
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
|
|
sm.set_array([])
|
|
|
|
|
|
- # img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
|
- # print(f'boxes:{target["boxes"]}')
|
|
|
img = image
|
|
|
+ # print(f'img:{img.shape}')
|
|
|
+
|
|
|
+ if show_type == 'arc_yuan_point_ellipse':
|
|
|
+ arc_ends = target['mask_ends']
|
|
|
+ arc_params = target['mask_params']
|
|
|
+
|
|
|
+ fig, ax = plt.subplots()
|
|
|
+ ax.imshow(img.permute(1, 2, 0))
|
|
|
+
|
|
|
+ for params in arc_params:
|
|
|
+ if torch.all(params == 0):
|
|
|
+ continue
|
|
|
+ x, y, a, b, q = params
|
|
|
+ theta = np.radians(q)
|
|
|
+ phi = np.linspace(0, 2 * np.pi, 500)
|
|
|
+ x_ellipse = x + a * np.cos(phi) * np.cos(theta) - b * np.sin(phi) * np.sin(theta)
|
|
|
+ y_ellipse = y + a * np.cos(phi) * np.sin(theta) + b * np.sin(phi) * np.cos(theta)
|
|
|
+
|
|
|
+ plt.plot(x_ellipse, y_ellipse, 'b-', linewidth=2)
|
|
|
+
|
|
|
+ for point2 in arc_ends:
|
|
|
+ if torch.all(point2 == 0):
|
|
|
+ continue
|
|
|
+ ends_np = point2.cpu().numpy()
|
|
|
+ ax.plot(ends_np[:, 0], ends_np[:, 1], 'ro', markersize=6, label='Arc Endpoints')
|
|
|
+
|
|
|
+ ax.legend()
|
|
|
+ plt.axis('image') # 保持比例一致
|
|
|
+ plt.show()
|
|
|
|
|
|
if show_type == 'circle_masks':
|
|
|
boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), target["boxes"],
|
|
|
@@ -281,7 +304,6 @@ def arc_to_mask(arc7, shape, line_width=1):
|
|
|
# 确保 phi1 -> phi2 是正向(可处理跨 2π 的情况)
|
|
|
if torch.all(arc7 == 0):
|
|
|
return torch.zeros(shape, dtype=torch.uint8)
|
|
|
- print_params(arc7)
|
|
|
|
|
|
xc, yc, a, b, theta, phi1, phi2 = arc7
|
|
|
H, W = shape
|
|
|
@@ -559,14 +581,10 @@ def get_boxes_lines(objs, shape):
|
|
|
mask_params = None
|
|
|
|
|
|
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
return boxes, line_point_pairs, points, labels, mask_ends, mask_params
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
- path = r'/data/share/zyh/master_dataset/pokou/251115/a_dataset_pokou_mask'
|
|
|
+ path = r'\\192.168.50.222/share/zyh/master_dataset/pokou/251115/a_dataset_pokou_mask'
|
|
|
dataset = LineDataset(dataset_path=path, dataset_type='train', augmentation=False, data_type='jpg')
|
|
|
- dataset.show(9, show_type='circle_masks')
|
|
|
+ dataset.show(19, show_type='arc_yuan_point_ellipse')
|