Ver Fonte

修改box损失过大问题

lstrlq há 6 meses atrás
pai
commit
a81df9a4d0
2 ficheiros alterados com 5 adições e 10 exclusões
  1. 4 9
      models/line_detect/dataset_LD.py
  2. 1 1
      models/line_detect/train_demo.py

+ 4 - 9
models/line_detect/dataset_LD.py

@@ -1,4 +1,3 @@
-# ??roi_head??????????????
 from torch.utils.data.dataset import T_co
 
 from models.base.base_dataset import BaseDataset
@@ -29,7 +28,7 @@ from tools.presets import DetectionPresetTrain
 
 def line_boxes1(target):
     boxs = []
-    lines = target.cpu().numpy() * 4
+    lines = target.cpu().numpy()
 
     if len(lines) > 0 and not (lines[0] == 0).all():
         for i, ((a, b)) in enumerate(lines):
@@ -111,7 +110,7 @@ class WirePointDataset(BaseDataset):
         ldir = lpre[:, 0, :2] - lpre[:, 1, :2]
         ldir /= np.clip(LA.norm(ldir, axis=1, keepdims=True), 1e-6, None)
         feat = [
-            lpre[:, :, :2].reshape(-1, 4) / 128 * use_cood,
+            lpre[:, :, :2].reshape(-1, 4) / 512 * use_cood,
             ldir * use_slop,
             lpre[:, :, 2],
         ]
@@ -199,12 +198,12 @@ class WirePointDataset(BaseDataset):
             if fn != None:
                 plt.savefig(fn)
 
-        junc = target['wires']['junc_coords'].cpu().numpy() * 4
+        junc = target['wires']['junc_coords'].cpu().numpy()
         jtyp = target['wires']['jtyp'].cpu().numpy()
         juncs = junc[jtyp == 0]
         junts = junc[jtyp == 1]
 
-        lpre = target['wires']["lpre"].cpu().numpy() * 4
+        lpre = target['wires']["lpre"].cpu().numpy()
         vecl_target = target['wires']["lpre_label"].cpu().numpy()
         lpre = lpre[vecl_target == 1]
 
@@ -215,7 +214,3 @@ class WirePointDataset(BaseDataset):
     def show_img(self, img_path):
         pass
 
-
-# dataset_train = WirePointDataset("/data/lm/dataset/0424_", dataset_type='val')
-# for i in dataset_train:
-#     a = 1

+ 1 - 1
models/line_detect/train_demo.py

@@ -12,7 +12,7 @@ if __name__ == '__main__':
     # model = linenet_resnet18_fpn()
     # model=get_line_net_convnext_fpn(num_classes=2).to(device)
     model=linenet_newresnet50fpn()
-    model.load_best_model('train_results/20250622_135121/weights/best_val.pth')
+    model.load_best_model('train_results/20250622_140412/weights/best_val.pth')
     # trainer = Trainer()
     # trainer.train_cfg(model,cfg='./train.yaml')
     model.start_train(cfg='train.yaml')