|
|
@@ -126,7 +126,7 @@ class LineDataset(BaseDataset):
|
|
|
|
|
|
# print(f'labels:{target["labels"]}')
|
|
|
# target["boxes"] = line_boxes(target)
|
|
|
- target["boxes"], lines = line_boxes(target)
|
|
|
+ target["boxes"], lines = get_boxes_lines(target)
|
|
|
target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
|
|
|
# keypoints=keypoints/512
|
|
|
# visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
|
|
|
@@ -166,15 +166,13 @@ class LineDataset(BaseDataset):
|
|
|
def show_img(self, img_path):
|
|
|
pass
|
|
|
|
|
|
-def line_boxes(target):
|
|
|
+def get_boxes_lines(target):
|
|
|
boxs = []
|
|
|
lpre = target['wires']["lpre"].cpu().numpy()
|
|
|
vecl_target = target['wires']["lpre_label"].cpu().numpy()
|
|
|
lpre = lpre[vecl_target == 1]
|
|
|
-
|
|
|
lines = lpre
|
|
|
sline = np.ones(lpre.shape[0])
|
|
|
-
|
|
|
line_point_pairs = []
|
|
|
|
|
|
if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
@@ -182,22 +180,14 @@ def line_boxes(target):
|
|
|
if i > 0 and (lines[i] == lines[0]).all():
|
|
|
break
|
|
|
# plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]无明确大小
|
|
|
-
|
|
|
line_point_pairs.append([a[1], a[0]])
|
|
|
line_point_pairs.append([b[1], b[0]])
|
|
|
|
|
|
- if a[1] > b[1]:
|
|
|
- ymax = a[1] + 1
|
|
|
- ymin = b[1] - 1
|
|
|
- else:
|
|
|
- ymin = a[1] - 1
|
|
|
- ymax = b[1] + 1
|
|
|
- if a[0] > b[0]:
|
|
|
- xmax = a[0] + 1
|
|
|
- xmin = b[0] - 1
|
|
|
- else:
|
|
|
- xmin = a[0] - 1
|
|
|
- xmax = b[0] + 1
|
|
|
+ xmin = min(a[0], b[0]) - 1
|
|
|
+ xmax = max(a[0], b[0]) + 1
|
|
|
+ ymin = min(a[1], b[1]) - 1
|
|
|
+ ymax = max(a[1], b[1]) + 1
|
|
|
+
|
|
|
boxs.append([ymin, xmin, ymax, xmax])
|
|
|
|
|
|
return torch.tensor(boxs), torch.tensor(line_point_pairs)
|