|
@@ -31,7 +31,10 @@ def validate_keypoints(keypoints, image_width, image_height):
|
|
|
if not (0 <= x < image_width and 0 <= y < image_height):
|
|
if not (0 <= x < image_width and 0 <= y < image_height):
|
|
|
raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
|
|
raise ValueError(f"Key point ({x}, {y}) is out of bounds for image size ({image_width}, {image_height})")
|
|
|
|
|
|
|
|
|
|
+"""
|
|
|
|
|
+Read the dataset directly from the JSON annotation format produced by xanlabel (presumably X-AnyLabeling; confirm).
|
|
|
|
|
|
|
|
|
|
+"""
|
|
|
class LineDataset(BaseDataset):
|
|
class LineDataset(BaseDataset):
|
|
|
def __init__(self, dataset_path, data_type, transforms=None, dataset_type=None,img_type='rgb', target_type='pixel'):
|
|
def __init__(self, dataset_path, data_type, transforms=None, dataset_type=None,img_type='rgb', target_type='pixel'):
|
|
|
super().__init__(dataset_path)
|
|
super().__init__(dataset_path)
|
|
@@ -50,28 +53,20 @@ class LineDataset(BaseDataset):
|
|
|
|
|
|
|
|
def __getitem__(self, index) -> T_co:
    """Load one sample: the RGB image and its annotation target.

    Args:
        index: Position of the sample in ``self.imgs``.

    Returns:
        A tuple ``(img, target)``. ``target`` is whatever
        ``self.read_target`` builds for this sample. ``img`` is the output
        of ``self.transforms`` when it is set, otherwise the output of
        ``self.default_transform``.
    """
    img_name = self.imgs[index]
    img_path = os.path.join(self.img_path, img_name)

    # The label file shares the image's stem. splitext copes with extensions
    # of any length (".jpeg", ".tiff"), unlike slicing off three characters.
    lbl_name = os.path.splitext(img_name)[0] + '.json'
    lbl_path = os.path.join(self.lbl_path, lbl_name)

    img = PIL.Image.open(img_path).convert('RGB')
    w, h = img.size  # PIL reports size as (width, height)

    # read_target is given the shape as (height, width).
    target = self.read_target(item=index, lbl_path=lbl_path, shape=(h, w))

    if self.transforms:
        img, target = self.transforms(img, target)
    else:
        img = self.default_transform(img)

    return img, target
|
|
|
|
|
|
|
|
def __len__(self):
|
|
def __len__(self):
|
|
@@ -83,78 +78,31 @@ class LineDataset(BaseDataset):
|
|
|
with open(lbl_path, 'r') as file:
|
|
with open(lbl_path, 'r') as file:
|
|
|
lable_all = json.load(file)
|
|
lable_all = json.load(file)
|
|
|
|
|
|
|
|
- n_stc_posl = 300
|
|
|
|
|
- n_stc_negl = 40
|
|
|
|
|
- use_cood = 0
|
|
|
|
|
- use_slop = 0
|
|
|
|
|
-
|
|
|
|
|
- wire = lable_all["wires"][0] # 字典
|
|
|
|
|
- line_pos_coords = np.random.permutation(wire["line_pos_coords"]["content"])[: n_stc_posl] # 不足,有多少取多少
|
|
|
|
|
- line_neg_coords = np.random.permutation(wire["line_neg_coords"]["content"])[: n_stc_negl]
|
|
|
|
|
- npos, nneg = len(line_pos_coords), len(line_neg_coords)
|
|
|
|
|
- lpre = np.concatenate([line_pos_coords, line_neg_coords], 0) # 正负样本坐标合在一起
|
|
|
|
|
- for i in range(len(lpre)):
|
|
|
|
|
- if random.random() > 0.5:
|
|
|
|
|
- lpre[i] = lpre[i, ::-1]
|
|
|
|
|
- ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
|
|
|
|
|
- ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
|
|
|
|
|
- feat = [
|
|
|
|
|
- lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
|
|
|
|
|
- ldir * use_slop,
|
|
|
|
|
- lpre[:, :, 2],
|
|
|
|
|
- ]
|
|
|
|
|
- feat = np.concatenate(feat, 1)
|
|
|
|
|
-
|
|
|
|
|
- wire_labels = {
|
|
|
|
|
- "junc_coords": torch.tensor(wire["junc_coords"]["content"]),
|
|
|
|
|
- "jtyp": torch.tensor(wire["junc_coords"]["content"])[:, 2].byte(),
|
|
|
|
|
- "line_pos_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_pos_idx"]["content"]),
|
|
|
|
|
- # 真实存在线条的邻接矩阵
|
|
|
|
|
- "line_neg_idx": adjacency_matrix(len(wire["junc_coords"]["content"]), wire["line_neg_idx"]["content"]),
|
|
|
|
|
-
|
|
|
|
|
- "lpre": torch.tensor(lpre)[:, :, :2],
|
|
|
|
|
- "lpre_label": torch.cat([torch.ones(npos), torch.zeros(nneg)]), # 样本对应标签 1,0
|
|
|
|
|
- "lpre_feat": torch.from_numpy(feat),
|
|
|
|
|
- "junc_map": torch.tensor(wire['junc_map']["content"]),
|
|
|
|
|
- "junc_offset": torch.tensor(wire['junc_offset']["content"]),
|
|
|
|
|
- "line_map": torch.tensor(wire['line_map']["content"]),
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- labels = []
|
|
|
|
|
- if self.target_type == 'polygon':
|
|
|
|
|
- labels, masks = read_masks_from_txt_wire(lbl_path, shape)
|
|
|
|
|
- elif self.target_type == 'pixel':
|
|
|
|
|
- labels = read_masks_from_pixels_wire(lbl_path, shape)
|
|
|
|
|
-
|
|
|
|
|
- # print(torch.stack(masks).shape) # [线段数, 512, 512]
|
|
|
|
|
|
|
+
|
|
|
|
|
+ objs = lable_all["shapes"]
|
|
|
|
|
+ point_pairs=objs[0]['points']
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+ # print(f'point_pairs:{point_pairs}')
|
|
|
target = {}
|
|
target = {}
|
|
|
|
|
|
|
|
target["image_id"] = torch.tensor(item)
|
|
target["image_id"] = torch.tensor(item)
|
|
|
- # return wire_labels, target
|
|
|
|
|
- target["wires"] = wire_labels
|
|
|
|
|
-
|
|
|
|
|
- # target["labels"] = torch.stack(labels)
|
|
|
|
|
|
|
|
|
|
- # print(f'labels:{target["labels"]}')
|
|
|
|
|
- # target["boxes"] = line_boxes(target)
|
|
|
|
|
- target["boxes"], lines = get_boxes_lines(target)
|
|
|
|
|
|
|
+ target["boxes"], lines = get_boxes_lines(objs,shape)
|
|
|
|
|
+ # print(f'lines:{lines}')
|
|
|
target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
|
|
target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
|
|
|
- # keypoints=keypoints/512
|
|
|
|
|
- # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
|
|
|
|
|
|
|
|
|
|
- # keypoints= wire_labels["junc_coords"]
|
|
|
|
|
|
|
+
|
|
|
a = torch.full((lines.shape[0],), 2).unsqueeze(1)
|
|
a = torch.full((lines.shape[0],), 2).unsqueeze(1)
|
|
|
lines = torch.cat((lines, a), dim=1)
|
|
lines = torch.cat((lines, a), dim=1)
|
|
|
|
|
|
|
|
target["lines"] = lines.to(torch.float32).view(-1,2,3)
|
|
target["lines"] = lines.to(torch.float32).view(-1,2,3)
|
|
|
- # print(f'boxes:{target["boxes"].shape}')
|
|
|
|
|
- # 在 __getitem__ 方法中调用此函数
|
|
|
|
|
|
|
+ target["img_size"]=shape
|
|
|
|
|
+
|
|
|
validate_keypoints(lines, shape[0], shape[1])
|
|
validate_keypoints(lines, shape[0], shape[1])
|
|
|
- # print(f'keypoints:{target["keypoints"].shape}')
|
|
|
|
|
- # print(f'target:{target}')
|
|
|
|
|
return target
|
|
return target
|
|
|
|
|
|
|
|
- def show(self, idx):
|
|
|
|
|
|
|
+ def show(self, idx,show_type='all'):
|
|
|
image, target = self.__getitem__(idx)
|
|
image, target = self.__getitem__(idx)
|
|
|
|
|
|
|
|
cmap = plt.get_cmap("jet")
|
|
cmap = plt.get_cmap("jet")
|
|
@@ -164,12 +112,23 @@ class LineDataset(BaseDataset):
|
|
|
|
|
|
|
|
img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
|
img = PIL.Image.open(img_path).convert('RGB')
|
|
img = PIL.Image.open(img_path).convert('RGB')
|
|
|
- boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
|
|
|
|
|
|
|
+ if show_type=='all':
|
|
|
|
|
+ boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
|
|
|
|
|
+ colors="yellow", width=1)
|
|
|
|
|
+ keypoint_img=draw_keypoints(boxed_image,target['lines'],colors='red',width=3)
|
|
|
|
|
+ plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
|
|
|
|
|
+ plt.show()
|
|
|
|
|
+
|
|
|
|
|
+ if show_type=='lines':
|
|
|
|
|
+ keypoint_img=draw_keypoints((self.default_transform(img) * 255).to(torch.uint8),target['lines'],colors='red',width=3)
|
|
|
|
|
+ plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
|
|
|
|
|
+ plt.show()
|
|
|
|
|
+
|
|
|
|
|
+ if show_type=='boxes':
|
|
|
|
|
+ boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
|
|
|
colors="yellow", width=1)
|
|
colors="yellow", width=1)
|
|
|
- keypoint_img=draw_keypoints(boxed_image,target['keypoints'],colors='red',width=3)
|
|
|
|
|
- plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
|
|
|
|
|
- plt.show()
|
|
|
|
|
-
|
|
|
|
|
|
|
+ plt.imshow(boxed_image.permute(1, 2, 0).numpy())
|
|
|
|
|
+ plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -177,33 +136,35 @@ class LineDataset(BaseDataset):
|
|
|
def show_img(self, img_path):
    """Placeholder with no behavior yet; accepts ``img_path`` and returns ``None``."""
    return None
|
|
|
|
|
|
|
|
def get_boxes_lines(objs, shape, padding=6):
    """Build padded bounding boxes and endpoint lists for annotated lines.

    Args:
        objs: Iterable of shape dicts. Each must have a ``'points'`` entry
            whose first two items are the ``(x, y)`` endpoints of one line.
        shape: ``(height, width)`` of the image, used to clamp the boxes.
        padding: Margin added around each line before clamping. Defaults to
            6, the value that was previously hard-coded.

    Returns:
        A tuple ``(boxes, line_point_pairs)`` of tensors. ``boxes`` has one
        ``[xmin, ymin, xmax, ymax]`` row per shape. ``line_point_pairs`` has
        two consecutive endpoint rows per shape. When ``objs`` is empty the
        tensors have shapes ``(0, 4)`` and ``(0, 2)``.
    """
    height, width = shape
    boxes = []
    line_point_pairs = []

    for obj in objs:
        a, b = obj['points'][0], obj['points'][1]
        line_point_pairs.extend((a, b))

        # Axis-aligned box around the segment, grown by `padding` and
        # clamped to the image extent.
        xmin = max(0, min(a[0], b[0]) - padding)
        xmax = min(width, max(a[0], b[0]) + padding)
        ymin = max(0, min(a[1], b[1]) - padding)
        ymax = min(height, max(a[1], b[1]) + padding)
        boxes.append([xmin, ymin, xmax, ymax])

    if not boxes:
        # torch.tensor([]) would be 1-D, which breaks any later
        # concatenation along dim=1; keep the column count explicit.
        return torch.zeros((0, 4)), torch.zeros((0, 2))

    return torch.tensor(boxes), torch.tensor(line_point_pairs)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke check: load the training split and display the annotated
    # lines of one sample.
    dataset_root = r"\\192.168.50.222/share/rlq/datasets/0706_"
    line_dataset = LineDataset(
        dataset_path=dataset_root,
        dataset_type='train',
        data_type='jpg',
    )
    line_dataset.show(1, show_type='lines')
|