|
@@ -27,17 +27,20 @@ from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks
|
|
|
|
|
|
|
|
from tools.presets import DetectionPresetTrain
|
|
from tools.presets import DetectionPresetTrain
|
|
|
|
|
|
|
|
|
|
+
|
|
|
def line_boxes1(target):
|
|
def line_boxes1(target):
|
|
|
boxs = []
|
|
boxs = []
|
|
|
lines = target.cpu().numpy() * 4
|
|
lines = target.cpu().numpy() * 4
|
|
|
|
|
|
|
|
- if len(lines) > 0 and not (lines[0] == 0).all():
|
|
|
|
|
|
|
+ if len(lines) > 0 :
|
|
|
for i, ((a, b)) in enumerate(lines):
|
|
for i, ((a, b)) in enumerate(lines):
|
|
|
if i > 0 and (lines[i] == lines[0]).all():
|
|
if i > 0 and (lines[i] == lines[0]).all():
|
|
|
break
|
|
break
|
|
|
# plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]ÎÞÃ÷È·´óС
|
|
# plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1) # a[1], b[1]ÎÞÃ÷È·´óС
|
|
|
- if a[-1]==0. and b[-1]==0.:
|
|
|
|
|
- continue
|
|
|
|
|
|
|
+ # if a[-1]==0. and b[-1]==0.:
|
|
|
|
|
+ # continue
|
|
|
|
|
+ # if a[:2].tolist() == [0., 0.] and b[:2].tolist() == [0., 0.]:
|
|
|
|
|
+ # continue
|
|
|
|
|
|
|
|
if a[1] > b[1]:
|
|
if a[1] > b[1]:
|
|
|
ymax = a[1] + 10
|
|
ymax = a[1] + 10
|
|
@@ -53,8 +56,10 @@ def line_boxes1(target):
|
|
|
xmax = b[0] + 10
|
|
xmax = b[0] + 10
|
|
|
boxs.append([ymin, xmin, ymax, xmax])
|
|
boxs.append([ymin, xmin, ymax, xmax])
|
|
|
|
|
|
|
|
- # if boxs == []:
|
|
|
|
|
- # print(target)
|
|
|
|
|
|
|
+ # print(f'box:{boxs}')
|
|
|
|
|
+ if boxs == []:
|
|
|
|
|
+ print(f'box:{boxs}')
|
|
|
|
|
+ print(f'target:{target}')
|
|
|
|
|
|
|
|
return torch.tensor(boxs)
|
|
return torch.tensor(boxs)
|
|
|
|
|
|
|
@@ -146,14 +151,15 @@ class WirePointDataset(BaseDataset):
|
|
|
target = {}
|
|
target = {}
|
|
|
# target["labels"] = torch.stack(labels)
|
|
# target["labels"] = torch.stack(labels)
|
|
|
|
|
|
|
|
-
|
|
|
|
|
target["image_id"] = torch.tensor(item)
|
|
target["image_id"] = torch.tensor(item)
|
|
|
# return wire_labels, target
|
|
# return wire_labels, target
|
|
|
target["wires"] = wire_labels
|
|
target["wires"] = wire_labels
|
|
|
- target["boxes"] = line_boxes(target)
|
|
|
|
|
- # target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
|
|
|
|
|
- target["labels"]= torch.ones(len(target["boxes"]),dtype=torch.int64)
|
|
|
|
|
|
|
+ # target["boxes"] = line_boxes(target)
|
|
|
|
|
+ target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
|
|
|
|
|
+ target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
|
|
|
# print(f'target["labels"]:{ target["labels"]}')
|
|
# print(f'target["labels"]:{ target["labels"]}')
|
|
|
|
|
+ # if target["boxes"].shape == [0]:
|
|
|
|
|
+ # print(f'box is null:{lbl_path}')
|
|
|
# print(f'boxes:{target["boxes"].shape}')
|
|
# print(f'boxes:{target["boxes"].shape}')
|
|
|
return target
|
|
return target
|
|
|
|
|
|
|
@@ -187,7 +193,6 @@ class WirePointDataset(BaseDataset):
|
|
|
break
|
|
break
|
|
|
plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # 原 s=64
|
|
plt.scatter(j[1], j[0], c="red", s=2, zorder=100) # 原 s=64
|
|
|
|
|
|
|
|
-
|
|
|
|
|
img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
img_path = os.path.join(self.img_path, self.imgs[idx])
|
|
|
img = PIL.Image.open(img_path).convert('RGB')
|
|
img = PIL.Image.open(img_path).convert('RGB')
|
|
|
boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
|
|
boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
|
|
@@ -211,6 +216,10 @@ class WirePointDataset(BaseDataset):
|
|
|
# draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
|
|
# draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
|
|
|
draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
|
|
draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
|
|
|
|
|
|
|
|
-
|
|
|
|
|
def show_img(self, img_path):
|
|
def show_img(self, img_path):
|
|
|
pass
|
|
pass
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# dataset_train = WirePointDataset(r"\\192.168.50.222\share\lm\04\424-转分好的zjf", dataset_type='train')
|
|
|
|
|
+# for i in dataset_train:
|
|
|
|
|
+# a = 1
|