
dataset modifications

xue50 7 months ago
parent
commit
30f4f6ca9a

+ 8 - 8
models/line_detect/111.py

@@ -234,16 +234,16 @@ if __name__ == '__main__':
     # model = LineNet('line_net.yaml')
     model = linenet_resnet50_fpn().to(device)
     # model=linenet_resnet18_fpn()
-    # trainer = Trainer()
-    # trainer.train_cfg(model,cfg='./train.yaml')
-    # model.train_by_cfg(cfg='train.yaml')
-    # trainer = Trainer()
-    # trainer.train_cfg(model=model, cfg='train.yaml')
+    trainer = Trainer()
+    trainer.train_cfg(model,cfg='./train.yaml')
+    model.train_by_cfg(cfg='train.yaml')
+    trainer = Trainer()
+    trainer.train_cfg(model=model, cfg='train.yaml')
     #
     # pt_path = r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth"
     # img_path = r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png"
     # model.predict(pt_path, model, img_path, type=1, threshold=0, save_path=None, show=True)
 
-    model = model.load_best_model(model, r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth")
-    img_path = r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png"
-    model.predict1(model, img_path, type=1, threshold=0, save_path=None, show=True)
+    # model = model.load_best_model(model, r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth")
+    # img_path = r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png"
+    # model.predict1(model, img_path, type=1, threshold=0, save_path=None, show=True)
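
Note on 111.py: the commit flips the __main__ block from prediction back to training, but as committed it now invokes training three times (trainer.train_cfg twice plus model.train_by_cfg). A minimal sketch of a single training entry point under the interfaces shown in the diff; the import paths below are assumptions, not taken from this repository.

import torch

# Assumed import locations; adjust to wherever Trainer and linenet_resnet50_fpn
# actually live in this repository.
from models.line_detect.linenet import linenet_resnet50_fpn
from models.line_detect.trainer import Trainer

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = linenet_resnet50_fpn().to(device)

    # One training invocation should be enough; train.yaml supplies datadir, logdir, etc.
    trainer = Trainer()
    trainer.train_cfg(model=model, cfg='./train.yaml')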

+ 7 - 24
models/line_detect/aaa.py

@@ -280,7 +280,7 @@ def predict(image_path):
 
     start_time = time.time()
 
-    model_path = r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth"
+    model_path = r"\\192.168.50.222\share\lm\weight\20250425_112601\weights\best.pth"
     model = load_model(model_path)
 
     img_tensor,_ = preprocess_image(image_path)
@@ -289,9 +289,12 @@ def predict(image_path):
     # Model inference
     with torch.no_grad():
       predictions = model([img_tensor.to(device)])
-    # print(f'predictions[0]:{predictions[0]}')
+    print(f'predictions[1][0].shape:{predictions[1][0].shape}')   # the second element is the feature map [1,256,128,128]
+    plt.imshow(predictions[1][0][2].cpu())
+    plt.show()
     # print(f'predictions[1]:{predictions[1]["wires"]["lines"]}')
     # lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 512 * np.array([2112, 1328])
+    '''
 
     start_time1 = time.time()
     show_line(img_tensor.permute(1, 2, 0).cpu().numpy(), predictions, start_time1)
@@ -310,27 +313,6 @@ def predict(image_path):
     diag = (512 ** 2 + 512 ** 2) ** 0.5
     nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
 
-
-    # t_start = time.time()
-    # filtered_pred = box_line_optimized_parallel(img_tensor,predictions)
-    # # print(f'after matching:{filtered_pred}')
-    # print(f'after matching len:{filtered_pred[0]}')
-    # t_end = time.time()
-    # print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')
-    # show_predict(img_tensor.permute(1, 2, 0).cpu().numpy(), filtered_pred, start_time1)
-
-
-    # # Merge images
-    # combined_image_path = "combined_result.png"
-    # combine_images(
-    #     [output_path_boxandline, output_path_box, output_path_line],
-    #     titles=["Box and Line", "Box", "Line"],
-    #     output_path=combined_image_path
-    # )
-
-    # end_time = time.time()
-    # print(f'Total time: {end_time - start_time:.2f} seconds')
-
     # lines = filtered_pred[0]['line'].cpu().numpy() / 512 * np.array([2112, 1328])
     print(f'line segments len:{len(nlines)}')
     # print(f"Initial lines shape: {lines.shape}")
@@ -355,6 +337,7 @@ def predict(image_path):
     print(f"Final result[0] shape: {result[0].shape}")
 
     return result
+    '''
 
 
 
@@ -470,5 +453,5 @@ def show_predict(im, filtered_pred, t_start):
     return output_path
 
 if __name__ == "__main__":
-    lines = predict(r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png")
+    lines = predict(r"C:\Users\m2337\Desktop\p\140502.png")
     print(f'lines:{lines}')
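
Note on aaa.py: the new debug lines print the shape of the second element of the model output and display one of its channels, and the remainder of the post-processing is then disabled with a triple-quoted string. A minimal sketch of that visualization, assuming predictions[1][0] is the [1, 256, 128, 128] feature tensor the inline comment describes; the helper name and channel index are illustrative only.

import matplotlib.pyplot as plt

def show_feature_channel(predictions, channel=2):
    # Second element of the model output, assumed shape [1, 256, 128, 128].
    fmap = predictions[1][0]
    if fmap.dim() == 4:                          # drop the batch dimension if present
        fmap = fmap[0]
    plt.imshow(fmap[channel].detach().cpu())     # render one channel as a 2-D image
    plt.title(f'feature map channel {channel}')
    plt.show()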

+ 20 - 11
models/line_detect/dataset_LD.py

@@ -27,17 +27,20 @@ from models.dataset_tool import line_boxes, read_masks_from_txt_wire, read_masks
 
 from tools.presets import DetectionPresetTrain
 
+
 def line_boxes1(target):
     boxs = []
     lines = target.cpu().numpy() * 4
 
-    if len(lines) > 0 and not (lines[0] == 0).all():
+    if len(lines) > 0 :
         for i, ((a, b)) in enumerate(lines):
             if i > 0 and (lines[i] == lines[0]).all():
                 break
             # plt.plot([a[1], b[1]], [a[0], b[0]], c="red", linewidth=1)  # the relative size of a[1], b[1] is not fixed
-            if a[-1]==0. and b[-1]==0.:
-                continue
+            # if a[-1]==0. and b[-1]==0.:
+            #     continue
+            # if a[:2].tolist() == [0., 0.] and b[:2].tolist() == [0., 0.]:
+            #     continue
 
             if a[1] > b[1]:
                 ymax = a[1] + 10
@@ -53,8 +56,10 @@ def line_boxes1(target):
                 xmax = b[0] + 10
             boxs.append([ymin, xmin, ymax, xmax])
 
-    # if boxs == []:
-    #     print(target)
+    # print(f'box:{boxs}')
+    if boxs == []:
+        print(f'box:{boxs}')
+        print(f'target:{target}')
 
     return torch.tensor(boxs)
 
@@ -146,14 +151,15 @@ class WirePointDataset(BaseDataset):
         target = {}
         # target["labels"] = torch.stack(labels)
 
-
         target["image_id"] = torch.tensor(item)
         # return wire_labels, target
         target["wires"] = wire_labels
-        target["boxes"] = line_boxes(target)
-        # target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
-        target["labels"]= torch.ones(len(target["boxes"]),dtype=torch.int64)
+        # target["boxes"] = line_boxes(target)
+        target["boxes"] = line_boxes1(torch.tensor(wire["line_pos_coords"]["content"]))
+        target["labels"] = torch.ones(len(target["boxes"]), dtype=torch.int64)
         # print(f'target["labels"]:{ target["labels"]}')
+        # if target["boxes"].shape == [0]:
+        #     print(f'box is null:{lbl_path}')
         # print(f'boxes:{target["boxes"].shape}')
         return target
 
@@ -187,7 +193,6 @@ class WirePointDataset(BaseDataset):
                         break
                     plt.scatter(j[1], j[0], c="red", s=2, zorder=100)  # originally s=64
 
-
             img_path = os.path.join(self.img_path, self.imgs[idx])
             img = PIL.Image.open(img_path).convert('RGB')
             boxed_image = draw_bounding_boxes((self.default_transform(img) * 255).to(torch.uint8), target["boxes"],
@@ -211,6 +216,10 @@ class WirePointDataset(BaseDataset):
         # draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, save_path)
         draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts)
 
-
     def show_img(self, img_path):
         pass
+
+
+# dataset_train = WirePointDataset(r"\\192.168.50.222\share\lm\04\424-转分好的zjf", dataset_type='train')
+# for i in dataset_train:
+#     a = 1
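
Note on dataset_LD.py: target["boxes"] is now built by line_boxes1 from the raw line_pos_coords instead of line_boxes(target), the all-zero endpoint filter is dropped, and samples whose box list comes out empty are logged; the commented-out loop above is a smoke test over the training split. A minimal sketch of the box construction line_boxes1 performs, assuming endpoints are (y, x) pairs scaled by 4 and that the hunk's hidden branches pad both box edges by 10 pixels; helper names are illustrative.

import torch

def lines_to_boxes(lines_tensor, scale=4, pad=10):
    boxes = []
    lines = lines_tensor.cpu().numpy() * scale          # heatmap -> image coordinates
    for i, (a, b) in enumerate(lines):
        if i > 0 and (lines[i] == lines[0]).all():
            break                                        # padded entries repeat the first line
        ymin, ymax = sorted((a[1], b[1]))
        xmin, xmax = sorted((a[0], b[0]))
        boxes.append([ymin - pad, xmin - pad, ymax + pad, xmax + pad])
    return torch.tensor(boxes)                           # empty list -> tensor of shape [0]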

+ 2 - 2
models/line_detect/predict2.py

@@ -470,9 +470,9 @@ if __name__ == '__main__':
     model = linenet_resnet50_fpn().to(device)
     # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练24轮结果.pth"
     # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
-    pt_path = r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth"
+    pt_path = r"\\192.168.50.222\share\lm\weight\20250425_112601\weights\best.pth"
     # pt_path = r"C:\Users\m2337\Downloads\best_e20.pth"
-    img_path = r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png"
+    img_path = r"C:\Users\m2337\Desktop\p\140502.png"
     predict(pt_path, model, img_path)
     t_end = time.time()
     print(f'predict used:{t_end - t_start}')

+ 1 - 1
models/line_detect/train.yaml

@@ -1,7 +1,7 @@
 io:
   logdir: logs/
 #  datadir: I:/datasets/4_23jiagonggongjian
-  datadir: I:/datasets/0322_suanzaisheng
+  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
   resume_from:
   num_workers: 8
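
Note on train.yaml: io.datadir is switched to the 0423_ dataset root. A minimal sketch for sanity-checking the io section before launching training, assuming only the keys shown in this hunk and using PyYAML.

import os
import yaml

with open('train.yaml', 'r', encoding='utf-8') as f:
    cfg = yaml.safe_load(f)

io_cfg = cfg['io']
print(f"datadir: {io_cfg['datadir']} (exists: {os.path.isdir(io_cfg['datadir'])})")
print(f"logdir: {io_cfg['logdir']}, num_workers: {io_cfg['num_workers']}")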