|
|
@@ -1,5 +1,6 @@
|
|
|
from torch.utils.data.dataset import T_co
|
|
|
|
|
|
+from libs.vision_libs.utils import draw_keypoints
|
|
|
from models.base.base_dataset import BaseDataset
|
|
|
|
|
|
import glob
|
|
|
@@ -22,7 +23,7 @@ from torch.utils.data import Dataset
|
|
|
from torch.utils.data.dataloader import default_collate
|
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
-from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
|
|
|
+from models.dataset_tool import read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
|
|
|
|
|
|
def validate_keypoints(keypoints, image_width, image_height):
|
|
|
for kp in keypoints:
|
|
|
@@ -121,10 +122,12 @@ class KeypointDataset(BaseDataset):
|
|
|
# return wire_labels, target
|
|
|
target["wires"] = wire_labels
|
|
|
|
|
|
- target["labels"] = torch.stack(labels)
|
|
|
+ # target["labels"] = torch.stack(labels)
|
|
|
+
|
|
|
# print(f'labels:{target["labels"]}')
|
|
|
# target["boxes"] = line_boxes(target)
|
|
|
target["boxes"], keypoints = line_boxes(target)
|
|
|
+ target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
|
|
|
# keypoints=keypoints/512
|
|
|
# visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
|
|
|
|
|
|
@@ -147,59 +150,61 @@ class KeypointDataset(BaseDataset):
|
|
|
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
|
|
sm.set_array([])
|
|
|
|
|
|
- def imshow(im):
|
|
|
- plt.close()
|
|
|
- plt.tight_layout()
|
|
|
- plt.imshow(im)
|
|
|
- plt.colorbar(sm, fraction=0.046)
|
|
|
- plt.xlim([0, im.shape[0]])
|
|
|
- plt.ylim([im.shape[0], 0])
|
|
|
-
|
|
|
- def draw_vecl(lines, sline, juncs, junts, fn=None):
|
|
|
- img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
|
- imshow(io.imread(img_path))
|
|
|
- if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
- for i, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
- if i > 0 and (lines[i] == lines[0]).all():
|
|
|
- break
|
|
|
- plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
|
|
|
- if not (juncs[0] == 0).all():
|
|
|
- for i, j in enumerate(juncs):
|
|
|
- if i > 0 and (i == juncs[0]).all():
|
|
|
- break
|
|
|
- plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # 原 s=64
|
|
|
-
|
|
|
-
|
|
|
- img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
|
- img = PIL.Image.open(img_path).convert('RGB')
|
|
|
- boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
|
|
|
- colors="yellow", width=1)
|
|
|
- plt.imshow(boxed_image.permute(1, 2, 0).numpy())
|
|
|
- plt.show()
|
|
|
|
|
|
- plt.show()
|
|
|
- if fn != None:
|
|
|
- plt.savefig(fn)
|
|
|
|
|
|
- junc = target['wires']['junc_coords'].cpu().numpy() * 4
|
|
|
- jtyp = target['wires']['jtyp'].cpu().numpy()
|
|
|
- juncs = junc[jtyp == 0]
|
|
|
- junts = junc[jtyp == 1]
|
|
|
|
|
|
- lpre = target['wires']["lpre"].cpu().numpy() * 4
|
|
|
- vecl_target = target['wires']["lpre_label"].cpu().numpy()
|
|
|
- lpre = lpre[vecl_target == 1]
|
|
|
+ img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
|
+ img = PIL.Image.open(img_path).convert('RGB')
|
|
|
+ boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
|
|
|
+ colors="yellow", width=1)
|
|
|
+ keypoint_img=draw_keypoints(boxed_image,target['keypoints'],colors='red',width=3)
|
|
|
+ plt.imshow(keypoint_img.permute(1, 2, 0).numpy())
|
|
|
+ plt.show()
|
|
|
+
|
|
|
+
|
|
|
|
|
|
- # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
|
|
|
- draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
|
|
|
|
|
|
|
|
|
def show_img(self, img_path):
|
|
|
pass
|
|
|
|
|
|
-
|
|
|
+def line_boxes(target):
|
|
|
+ boxs = []
|
|
|
+ lpre = target['wires']["lpre"].cpu().numpy()
|
|
|
+ vecl_target = target['wires']["lpre_label"].cpu().numpy()
|
|
|
+ lpre = lpre[vecl_target == 1]
|
|
|
+
|
|
|
+ lines = lpre
|
|
|
+ sline = np.ones(lpre.shape[0])
|
|
|
+
|
|
|
+ keypoints = []
|
|
|
+
|
|
|
+ if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
+ for i, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
+ if i > 0 and (lines[i] == lines[0]).all():
|
|
|
+ break
|
|
|
+ # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
|
|
|
+
|
|
|
+ keypoints.append([a[1], a[0]])
|
|
|
+ keypoints.append([b[1], b[0]])
|
|
|
+
|
|
|
+ if a[1] > b[1]:
|
|
|
+ ymax = a[1] + 1
|
|
|
+ ymin = b[1] - 1
|
|
|
+ else:
|
|
|
+ ymin = a[1] - 1
|
|
|
+ ymax = b[1] + 1
|
|
|
+ if a[0] > b[0]:
|
|
|
+ xmax = a[0] + 1
|
|
|
+ xmin = b[0] - 1
|
|
|
+ else:
|
|
|
+ xmin = a[0] - 1
|
|
|
+ xmax = b[0] + 1
|
|
|
+ boxs.append([ymin, xmin, ymax, xmax])
|
|
|
+
|
|
|
+ return torch.tensor(boxs), torch.tensor(keypoints)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
- path=r"I:\datasets\wirenet_1000"
|
|
|
+ path=r"\\192.168.50.222/share/lm/Dataset_all"
|
|
|
dataset= KeypointDataset(dataset_path=path, dataset_type='train')
|
|
|
- dataset.show(7)
|
|
|
+ dataset.show(10)
|