RenLiqiang 7 mēneši atpakaļ
vecāks
revīzija
34d748b2dc
2 mainītis faili ar 17 papildinājumiem un 14 dzēšanām
  1. 11 9
      models/line_detect/111.py
  2. 6 5
      models/line_detect/roi_heads.py

+ 11 - 9
models/line_detect/111.py

@@ -231,15 +231,17 @@ if __name__ == '__main__':
 
     # model = LineNet('line_net.yaml')
     model=linenet_resnet50_fpn().to(device)
-    model=linenet_resnet18_fpn()
-    trainer = Trainer()
-    trainer.train_cfg(model,cfg='./train.yaml')
-    model.train_by_cfg(cfg='train.yaml')
-    trainer = Trainer()
-    trainer.train_cfg(model=model, cfg='train.yaml')
-    # pt_path = r"C:\Users\m2337\Downloads\best_lmap代替x,训练75轮.pth"
-    # img_path = r"C:\Users\m2337\Desktop\p\新建文件夹\2025-03-25-16-10-00_SaveLeftImage.png"
-    # model.predict(pt_path, model, img_path, type=1, threshold=0, save_path=None, show=True)
+    # model=linenet_resnet18_fpn()
+    # trainer = Trainer()
+    # trainer.train_cfg(model,cfg='./train.yaml')
+    # model.train_by_cfg(cfg='train.yaml')
+    # trainer = Trainer()
+    # trainer.train_cfg(model=model, cfg='train.yaml')
+    #
+    pt_path = r"E:\projects\tmp\MultiVisionModels\models\line_detect\train_results\20250424_162124\weights\best.pth"
+    img_path = r"I:\datasets\4_23jiagonggongjian\images\val\2025-04-23-08-52-00_SaveRightImage.png"
+
+    model.predict(pt_path, model, img_path, type=1, threshold=0, save_path=None, show=True)
 
 
 

+ 6 - 5
models/line_detect/roi_heads.py

@@ -1050,15 +1050,16 @@ class RoIHeads(nn.Module):
                     }
                 )
 
-        features_lcnn = features['0']
+        line_features = features['0']
         if self.has_line():
             # print('has line_head')
             # outputs = self.line_head(features_lcnn)
-            outputs = features_lcnn[:, 0:5, :, :]
+            # outputs = line_features[:, 0:5, :, :]
+
 
             loss_weight = {'junc_map': 8.0, 'line_map': 0.5, 'junc_offset': 0.25, 'lpos': 1, 'lneg': 1}
             x, y, idx, jcs, n_batch, ps, n_out_line, n_out_junc = self.line_predictor(
-                inputs=outputs, features=features_lcnn, targets=targets)
+                inputs=line_features, features=line_features, targets=targets)
 
             # # line_loss(multitasklearner)
             # if self.training:
@@ -1071,12 +1072,12 @@ class RoIHeads(nn.Module):
             #                                        loss_weight, mode_train=False)
 
             if self.training:
-                rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, outputs, x, y, idx, loss_weight)
+                rcnn_loss_wirepoint = wirepoint_head_line_loss(targets, line_features, x, y, idx, loss_weight)
                 loss_wirepoint = {"loss_wirepoint": rcnn_loss_wirepoint}
             else:
 
                 pred = wirepoint_inference(x, idx, jcs, n_batch, ps, n_out_line, n_out_junc)
-                result.append(outputs)
+                result.append(line_features)
                 result.append(pred)
                 loss_wirepoint = {}
             losses.update(loss_wirepoint)