RenLiqiang 5 месяцев назад
Родитель
Сommit
feda62a418
2 измененных файлов с 8 добавлено и 18 удалено
  1. 7 17
      models/line_detect/line_dataset.py
  2. 1 1
      models/line_detect/trainer.py

+ 7 - 17
models/line_detect/line_dataset.py

@@ -126,7 +126,7 @@ class LineDataset(BaseDataset):
 
         # print(f'labels:{target["labels"]}')
         # target["boxes"] = line_boxes(target)
-        target["boxes"], lines = line_boxes(target)
+        target["boxes"], lines = get_boxes_lines(target)
         target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
         # keypoints=keypoints/512
         # visibility_flags = torch.ones((wire_labels["junc_coords"].shape[0], 1))
@@ -166,15 +166,13 @@ class LineDataset(BaseDataset):
     def show_img(self, img_path):
         pass
 
-def line_boxes(target):
+def get_boxes_lines(target):
     boxs = []
     lpre = target['wires']["lpre"].cpu().numpy()
     vecl_target = target['wires']["lpre_label"].cpu().numpy()
     lpre = lpre[vecl_target == 1]
-
     lines = lpre
     sline = np.ones(lpre.shape[0])
-
     line_point_pairs = []
 
     if len(lines) > 0 and not (lines[0] == 0).all():
@@ -182,22 +180,14 @@ def line_boxes(target):
             if i > 0 and (lines[i] == lines[0]).all():
                 break
             # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # a[1], b[1]无明确大小
-
             line_point_pairs.append([a[1], a[0]])
             line_point_pairs.append([b[1], b[0]])
 
-            if a[1] > b[1]:
-                ymax = a[1] + 1
-                ymin = b[1] - 1
-            else:
-                ymin = a[1] - 1
-                ymax = b[1] + 1
-            if a[0] > b[0]:
-                xmax = a[0] + 1
-                xmin = b[0] - 1
-            else:
-                xmin = a[0] - 1
-                xmax = b[0] + 1
+            xmin = min(a[0], b[0]) - 1
+            xmax = max(a[0], b[0]) + 1
+            ymin = min(a[1], b[1]) - 1
+            ymax = max(a[1], b[1]) + 1
+
             boxs.append([ymin, xmin, ymax, xmax])
 
     return torch.tensor(boxs), torch.tensor(line_point_pairs)

+ 1 - 1
models/line_detect/trainer.py

@@ -267,7 +267,7 @@ class Trainer(BaseTrainer):
                 t_start = time.time()
                 print(f'start to predict:{t_start}')
                 result = model(self.move_to_device(imgs, self.device))
-                print(f'result:{result}')
+                # print(f'result:{result}')
                 t_end = time.time()
                 print(f'predict used:{t_end - t_start}')
                 self.writer_predict_result(img=imgs[0], result=result[0], epoch=epoch)