|
@@ -22,7 +22,7 @@ from torch.utils.data import Dataset
|
|
|
from torch.utils.data.dataloader import default_collate
|
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
-from models.dataset_tool import masks_to_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
|
|
|
+from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks_from_pixels_wire, adjacency_matrix
|
|
|
|
|
|
|
|
|
class WirePointDataset(BaseDataset):
|
|
@@ -111,44 +111,12 @@ class WirePointDataset(BaseDataset):
|
|
|
|
|
|
# print(torch.stack(masks).shape) # [线段数, 512, 512]
|
|
|
target = {}
|
|
|
- # target["boxes"] = masks_to_boxes(torch.stack(masks))
|
|
|
- # print(target["boxes"])
|
|
|
target["labels"] = torch.stack(labels)
|
|
|
target["masks"] = torch.stack(masks)
|
|
|
target["image_id"] = torch.tensor(item)
|
|
|
# return wire_labels, target
|
|
|
target["wires"] = wire_labels
|
|
|
-
|
|
|
- boxs = []
|
|
|
- junc = target['wires']['junc_coords'].cpu().numpy() * 4
|
|
|
- lpre = target['wires']["lpre"].cpu().numpy() * 4
|
|
|
- vecl_target = target['wires']["lpre_label"].cpu().numpy()
|
|
|
- lpre = lpre[vecl_target == 1]
|
|
|
-
|
|
|
- lines = lpre
|
|
|
- sline = np.ones(lpre.shape[0])
|
|
|
-
|
|
|
- if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
- for i, ((a, b), s) in enumerate(zip(lines, sline)):
|
|
|
- if i > 0 and (lines[i] == lines[0]).all():
|
|
|
- break
|
|
|
- # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
|
|
|
- if a[1] > b[1]:
|
|
|
- ymax = a[1] + 1
|
|
|
- ymin = b[1] - 1
|
|
|
- else:
|
|
|
- ymin = a[1] - 1
|
|
|
- ymax = b[1] + 1
|
|
|
- if a[0] > b[0]:
|
|
|
- xmax = a[0] + 1
|
|
|
- xmin = b[0] - 1
|
|
|
- else:
|
|
|
- xmin = a[0] - 1
|
|
|
- xmax = b[0] + 1
|
|
|
- boxs.append([ymin, xmin, ymax, xmax])
|
|
|
- # plt.Rectangle([a[1] - 1, b[1] + 1], [a[0] + 1, b[0] - 1], c="g", linewidth=1)
|
|
|
- target["line_boxes"] = torch.tensor(boxs)
|
|
|
-
|
|
|
+ target["boxes"] = line_boxes(target)
|
|
|
return target
|
|
|
|
|
|
def show(self, idx):
|
|
@@ -184,7 +152,7 @@ class WirePointDataset(BaseDataset):
|
|
|
|
|
|
img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
|
img = PIL.Image.open(img_path).convert('RGB')
|
|
|
- boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["line_boxes"],
|
|
|
+ boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
|
|
|
colors="yellow", width=1)
|
|
|
plt.imshow(boxed_image.permute(1, 2, 0).numpy())
|
|
|
plt.show()
|