|
@@ -24,6 +24,7 @@ from torch.utils.data.dataloader import default_collate
|
|
|
import matplotlib.pyplot as plt
|
|
|
from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
|
|
|
|
|
|
+
|
|
|
def validate_keypoints(keypoints, image_width, image_height):
|
|
|
for kp in keypoints:
|
|
|
x, y, v = kp
|
|
@@ -123,14 +124,14 @@ class KeypointDataset(BaseDataset):
|
|
|
|
|
|
target["labels"] = torch.stack(labels)
|
|
|
# print(f'labels:{target["labels"]}')
|
|
|
- target["boxes"] = line_boxes(target)
|
|
|
+ target["boxes"], keypoints = line_boxes(target)
|
|
|
# visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
|
|
|
|
|
|
- keypoints= wire_labels["junc_coords"]
|
|
|
- keypoints[:,2]=2
|
|
|
- # keypoints[:,0]=keypoints[:,0]/shape[0]
|
|
|
- # keypoints[:, 1] = keypoints[:, 1] / shape[1]
|
|
|
- target["keypoints"]=keypoints
|
|
|
+ # keypoints= wire_labels["junc_coords"]
|
|
|
+ a = torch.full((keypoints.shape[0],), 2).unsqueeze(1)
|
|
|
+ keypoints = torch.cat((keypoints, a), dim=1)
|
|
|
+ target["keypoints"] = keypoints.to(torch.float32)
|
|
|
+ print(f'boxes:{target["boxes"].shape}')
|
|
|
# 在 __getitem__ 方法中调用此函数
|
|
|
validate_keypoints(keypoints, shape[0], shape[1])
|
|
|
print(f'keypoints:{target["keypoints"].shape}')
|
|
@@ -166,7 +167,6 @@ class KeypointDataset(BaseDataset):
|
|
|
break
|
|
|
plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # 原 s=64
|
|
|
|
|
|
-
|
|
|
img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
|
img = PIL.Image.open(img_path).convert('RGB')
|
|
|
boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
|
|
@@ -190,13 +190,11 @@ class KeypointDataset(BaseDataset):
|
|
|
# draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
|
|
|
draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
|
|
|
|
|
|
-
|
|
|
def show_img(self, img_path):
|
|
|
pass
|
|
|
|
|
|
|
|
|
-
|
|
|
if __name__ == '__main__':
|
|
|
- path=r"I:\wirenet_dateset"
|
|
|
- dataset= KeypointDataset(dataset_path=path, dataset_type='train')
|
|
|
+ path = r"I:\wirenet_dateset"
|
|
|
+ dataset = KeypointDataset(dataset_path=path, dataset_type='train')
|
|
|
dataset.show(0)
|