|
@@ -10,6 +10,10 @@ import random
|
|
|
import cv2
|
|
|
import PIL
|
|
|
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+import matplotlib as mpl
|
|
|
+from torchvision.utils import draw_bounding_boxes
|
|
|
+
|
|
|
import numpy as np
|
|
|
import numpy.linalg as LA
|
|
|
import torch
|
|
@@ -51,8 +55,10 @@ class WirePointDataset(BaseDataset):
|
|
|
|
|
|
# print(f'img:{img}')
|
|
|
return img, target
|
|
|
+
|
|
|
def __len__(self):
|
|
|
return len(self.imgs)
|
|
|
+
|
|
|
def read_target(self, item, lbl_path, shape, extra=None):
|
|
|
# print(f'lbl_path:{lbl_path}')
|
|
|
with open(lbl_path, 'r') as file:
|
|
@@ -103,49 +109,105 @@ class WirePointDataset(BaseDataset):
|
|
|
elif self.target_type == 'pixel':
|
|
|
labels, masks = read_masks_from_pixels_wire(lbl_path, shape)
|
|
|
|
|
|
+ # print(torch.stack(masks).shape) # [线段数, 512, 512]
|
|
|
target = {}
|
|
|
- target["boxes"] = masks_to_boxes(torch.stack(masks))
|
|
|
+ # target["boxes"] = masks_to_boxes(torch.stack(masks))
|
|
|
+ # print(target["boxes"])
|
|
|
target["labels"] = torch.stack(labels)
|
|
|
target["masks"] = torch.stack(masks)
|
|
|
target["image_id"] = torch.tensor(item)
|
|
|
# return wire_labels, target
|
|
|
target["wires"] = wire_labels
|
|
|
+
|
|
|
+ boxs = []
|
|
|
+ junc = target['wires']['junc_coords'].cpu().numpy() * 4
|
|
|
+ lpre = target['wires']["lpre"].cpu().numpy() * 4
|
|
|
+ vecl_target = target['wires']["lpre_label"].cpu().numpy()
|
|
|
+ lpre = lpre[vecl_target == 1]
|
|
|
+
|
|
|
+ lines = lpre
|
|
|
+ sline = np.ones(lpre.shape[0])
|
|
|
+
|
|
|
+ if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
+ for i, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
+ if i > 0 and (lines[i] == lines[0]).all():
|
|
|
+ break
|
|
|
+ # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
|
|
|
+ if a[1] > b[1]:
|
|
|
+ ymax = a[1] + 1
|
|
|
+ ymin = b[1] - 1
|
|
|
+ else:
|
|
|
+ ymin = a[1] - 1
|
|
|
+ ymax = b[1] + 1
|
|
|
+ if a[0] > b[0]:
|
|
|
+ xmax = a[0] + 1
|
|
|
+ xmin = b[0] - 1
|
|
|
+ else:
|
|
|
+ xmin = a[0] - 1
|
|
|
+ xmax = b[0] + 1
|
|
|
+ boxs.append([ymin, xmin, ymax, xmax])
|
|
|
+ # plt.Rectangle([a[1] - 1, b[1] + 1], [a[0] + 1, b[0] - 1], c="g", linewidth=1)
|
|
|
+ target["line_boxes"] = torch.tensor(boxs)
|
|
|
+
|
|
|
return target
|
|
|
|
|
|
def show(self, idx):
|
|
|
- img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
|
- lbl_path = os.path.join(self.lbl_path, self.imgs[idx][:-3] + 'json')
|
|
|
+ image, target = self.__getitem__(idx)
|
|
|
+
|
|
|
+ cmap = plt.get_cmap("jet")
|
|
|
+ norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
|
|
|
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
|
|
|
+ sm.set_array([])
|
|
|
+
|
|
|
+ def imshow(im):
|
|
|
+ plt.close()
|
|
|
+ plt.tight_layout()
|
|
|
+ plt.imshow(im)
|
|
|
+ plt.colorbar(sm, fraction=0.046)
|
|
|
+ plt.xlim([0, im.shape[0]])
|
|
|
+ plt.ylim([im.shape[0], 0])
|
|
|
+
|
|
|
+ def draw_vecl(lines, sline, juncs, junts, fn=None):
|
|
|
+ img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
|
+ imshow(io.imread(img_path))
|
|
|
+ if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
+ for i, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
+ if i > 0 and (lines[i] == lines[0]).all():
|
|
|
+ break
|
|
|
+ plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
|
|
|
+ if not (juncs[0] == 0).all():
|
|
|
+ for i, j in enumerate(juncs):
|
|
|
+ if i > 0 and (i == juncs[0]).all():
|
|
|
+ break
|
|
|
+ plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # 原 s=64
|
|
|
+
|
|
|
+
|
|
|
+ img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
|
+ img = PIL.Image.open(img_path).convert('RGB')
|
|
|
+ boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["line_boxes"],
|
|
|
+ colors="yellow", width=1)
|
|
|
+ plt.imshow(boxed_image.permute(1, 2, 0).numpy())
|
|
|
+ plt.show()
|
|
|
+
|
|
|
+ plt.show()
|
|
|
+ if fn != None:
|
|
|
+ plt.savefig(fn)
|
|
|
+
|
|
|
+ junc = target['wires']['junc_coords'].cpu().numpy() * 4
|
|
|
+ jtyp = target['wires']['jtyp'].cpu().numpy()
|
|
|
+ juncs = junc[jtyp == 0]
|
|
|
+ junts = junc[jtyp == 1]
|
|
|
+
|
|
|
+ lpre = target['wires']["lpre"].cpu().numpy() * 4
|
|
|
+ vecl_target = target['wires']["lpre_label"].cpu().numpy()
|
|
|
+ lpre = lpre[vecl_target == 1]
|
|
|
+
|
|
|
+ # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
|
|
|
+ draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
|
|
|
+
|
|
|
+
|
|
|
+ def show_img(self, img_path):
|
|
|
+ pass
|
|
|
|
|
|
- with open(lbl_path, 'r') as file:
|
|
|
- lable_all = json.load(file)
|
|
|
|
|
|
- # 可视化图像和标注
|
|
|
- image = cv2.imread(img_path) # [H,W,3] # 默认为BGR格式
|
|
|
- # print(image.shape)
|
|
|
- # 绘制每个标注的多边形
|
|
|
- # for ann in lable_all["segmentations"]:
|
|
|
- # segmentation = [[x * 512 for x in ann['data']]]
|
|
|
- # # segmentation = [ann['data']]
|
|
|
- # # for i in range(len(ann['data'])):
|
|
|
- # # if i % 2 == 0:
|
|
|
- # # segmentation[0][i] *= image.shape[0]
|
|
|
- # # else:
|
|
|
- # # segmentation[0][i] *= image.shape[0]
|
|
|
- #
|
|
|
- # # if isinstance(segmentation, list):
|
|
|
- # # for seg in segmentation:
|
|
|
- # # poly = np.array(seg).reshape((-1, 2)).astype(int)
|
|
|
- # # cv2.polylines(image, [poly], isClosed=True, color=(0, 255, 0), thickness=2)
|
|
|
- # # cv2.fillPoly(image, [poly], color=(0, 255, 0))
|
|
|
-
|
|
|
-
|
|
|
- #
|
|
|
- # # 显示图像
|
|
|
- # cv2.namedWindow('Image with Segmentations', cv2.WINDOW_NORMAL)
|
|
|
- # cv2.imshow('Image with Segmentations', image)
|
|
|
- # cv2.waitKey(0)
|
|
|
- # cv2.destroyAllWindows()
|
|
|
-
|
|
|
- def show_img(self,img_path):
|
|
|
- pass
|
|
|
|