|
|
@@ -98,7 +98,6 @@ class LineDataset(BaseDataset):
|
|
|
|
|
|
target["image_id"] = torch.tensor(item)
|
|
|
|
|
|
- #boxes, line_point_pairs, points, labels, mask_ends, mask_params
|
|
|
boxes, lines, points, labels, arc_ends, arc_params = get_boxes_lines(objs, shape)
|
|
|
|
|
|
if lines is not None:
|
|
|
@@ -120,16 +119,13 @@ class LineDataset(BaseDataset):
|
|
|
|
|
|
print(f'target[circle_masks]:{target["circle_masks"].shape}')
|
|
|
|
|
|
-
|
|
|
target["boxes"] = boxes
|
|
|
target["labels"] = labels
|
|
|
target["img_size"] = shape
|
|
|
|
|
|
- # validate_keypoints(lines, shape[0], shape[1])
|
|
|
return target
|
|
|
|
|
|
|
|
|
-
|
|
|
def show(self, idx, show_type='all'):
|
|
|
image, target = self.__getitem__(idx)
|
|
|
|
|
|
@@ -141,6 +137,23 @@ class LineDataset(BaseDataset):
|
|
|
img = image
|
|
|
# print(f'img:{img.shape}')
|
|
|
|
|
|
+ if show_type == 'Original_mask':
|
|
|
+
|
|
|
+ image = image.permute(1, 2, 0).cpu().numpy() # (3,H,W) -> (H,W,3)
|
|
|
+ arc_mask = target['circle_masks'][-1]
|
|
|
+
|
|
|
+ plt.figure(figsize=(10, 5))
|
|
|
+ plt.subplot(1, 2, 1)
|
|
|
+ plt.imshow(image)
|
|
|
+ plt.title('Original')
|
|
|
+
|
|
|
+ # 叠加 mask(用红色轮廓或填充)
|
|
|
+ plt.subplot(1, 2, 2)
|
|
|
+ plt.imshow(image)
|
|
|
+ plt.imshow(arc_mask, cmap='Reds', alpha=0.5) # 半透明红色
|
|
|
+ plt.title('Image with Mask')
|
|
|
+ plt.show()
|
|
|
+
|
|
|
if show_type == 'arc_yuan_point_ellipse':
|
|
|
arc_ends = target['mask_ends']
|
|
|
arc_params = target['mask_params']
|
|
|
@@ -171,10 +184,7 @@ class LineDataset(BaseDataset):
|
|
|
|
|
|
if show_type == 'circle_masks':
|
|
|
arc_mask = target['circle_masks']
|
|
|
- # print(f'taget circle:{arc.shape}')
|
|
|
print(f'target circle_masks:{arc_mask.shape}')
|
|
|
- combined = torch.cat(list(arc_mask), dim=1)
|
|
|
- print(f'combine:{combined.shape}')
|
|
|
plt.imshow(arc_mask[-1])
|
|
|
plt.show()
|
|
|
|
|
|
@@ -314,7 +324,6 @@ def arc_to_mask_safe(arc_param, arc_end, shape, line_width=5, debug=True, idx=-1
|
|
|
return torch.tensor(mask, dtype=torch.float32)
|
|
|
|
|
|
|
|
|
-
|
|
|
def draw_el(all):
|
|
|
# 解析椭圆参数
|
|
|
if isinstance(all, torch.Tensor):
|
|
|
@@ -535,47 +544,41 @@ def get_boxes_lines(objs, shape):
|
|
|
mask_params = []
|
|
|
|
|
|
for obj in objs:
|
|
|
- # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
|
|
|
|
|
|
- # print(f"points:{obj['points']}")
|
|
|
label = obj['label']
|
|
|
- if label == 'line' or label == 'dseam1':
|
|
|
- a, b = obj['points'][0], obj['points'][1]
|
|
|
-
|
|
|
- # line_point_pairs.append(a)
|
|
|
- # line_point_pairs.append(b)
|
|
|
- line_point_pairs.append([a, b])
|
|
|
-
|
|
|
- xmin = max(0, (min(a[0], b[0]) - 6))
|
|
|
- xmax = min(w, (max(a[0], b[0]) + 6))
|
|
|
- ymin = max(0, (min(a[1], b[1]) - 6))
|
|
|
- ymax = min(h, (max(a[1], b[1]) + 6))
|
|
|
-
|
|
|
- boxes.append([xmin, ymin, xmax, ymax])
|
|
|
- labels.append(torch.tensor(2))
|
|
|
-
|
|
|
- points.append(torch.tensor([0.0]))
|
|
|
- mask_ends.append([[0, 0], [0, 0]])
|
|
|
- mask_params.append([0, 0, 0, 0, 0])
|
|
|
- # circle_4points.append([[0, 0], [0, 0], [0, 0], [0, 0]])
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
- elif label == 'point':
|
|
|
- p = obj['points'][0]
|
|
|
- xmin = max(0, p[0] - 12)
|
|
|
- xmax = min(w, p[0] + 12)
|
|
|
- ymin = max(0, p[1] - 12)
|
|
|
- ymax = min(h, p[1] + 12)
|
|
|
-
|
|
|
- points.append(p)
|
|
|
- labels.append(torch.tensor(1))
|
|
|
- boxes.append([xmin, ymin, xmax, ymax])
|
|
|
-
|
|
|
- line_point_pairs.append([[0, 0], [0, 0]])
|
|
|
- mask_ends.append([[0, 0], [0, 0]])
|
|
|
- mask_params.append([0, 0, 0, 0, 0])
|
|
|
- # circle_4points.append([[0, 0], [0, 0], [0, 0], [0, 0]])
|
|
|
+ # if label == 'line' or label == 'dseam1':
|
|
|
+ # a, b = obj['points'][0], obj['points'][1]
|
|
|
+ # line_point_pairs.append([a, b])
|
|
|
+ #
|
|
|
+ # xmin = max(0, (min(a[0], b[0]) - 6))
|
|
|
+ # xmax = min(w, (max(a[0], b[0]) + 6))
|
|
|
+ # ymin = max(0, (min(a[1], b[1]) - 6))
|
|
|
+ # ymax = min(h, (max(a[1], b[1]) + 6))
|
|
|
+ #
|
|
|
+ # boxes.append([xmin, ymin, xmax, ymax])
|
|
|
+ # labels.append(torch.tensor(2))
|
|
|
+ #
|
|
|
+ # points.append(torch.tensor([0.0]))
|
|
|
+ # mask_ends.append([[0, 0], [0, 0]])
|
|
|
+ # mask_params.append([0, 0, 0, 0, 0])
|
|
|
+ # # circle_4points.append([[0, 0], [0, 0], [0, 0], [0, 0]])
|
|
|
+ #
|
|
|
+ #
|
|
|
+ # elif label == 'point':
|
|
|
+ # p = obj['points'][0]
|
|
|
+ # xmin = max(0, p[0] - 12)
|
|
|
+ # xmax = min(w, p[0] + 12)
|
|
|
+ # ymin = max(0, p[1] - 12)
|
|
|
+ # ymax = min(h, p[1] + 12)
|
|
|
+ #
|
|
|
+ # points.append(p)
|
|
|
+ # labels.append(torch.tensor(1))
|
|
|
+ # boxes.append([xmin, ymin, xmax, ymax])
|
|
|
+ #
|
|
|
+ # line_point_pairs.append([[0, 0], [0, 0]])
|
|
|
+ # mask_ends.append([[0, 0], [0, 0]])
|
|
|
+ # mask_params.append([0, 0, 0, 0, 0])
|
|
|
+ # # circle_4points.append([[0, 0], [0, 0], [0, 0], [0, 0]])
|
|
|
|
|
|
|
|
|
# elif label == 'arc':
|
|
|
@@ -598,7 +601,7 @@ def get_boxes_lines(objs, shape):
|
|
|
# line_point_pairs.append([[0, 0], [0, 0]])
|
|
|
# circle_4points.append([[0, 0], [0, 0], [0, 0], [0, 0]])
|
|
|
|
|
|
- elif label == 'arc':
|
|
|
+ if label == 'arc':
|
|
|
|
|
|
arc_params = obj['params']
|
|
|
arc_ends = obj['ends']
|
|
|
@@ -653,4 +656,5 @@ def get_boxes_lines(objs, shape):
|
|
|
if __name__ == '__main__':
|
|
|
path = r'/data/share/zyh/master_dataset/dataset_net/pokou_251115_251121/a_dataset'
|
|
|
dataset = LineDataset(dataset_path=path, dataset_type='train', augmentation=False, data_type='jpg')
|
|
|
- dataset.show(19, show_type='circle_masks')
|
|
|
+ for i in range(100):
|
|
|
+ dataset.show(i, show_type='Original_mask')
|