Просмотр исходного кода

修复train_demo 训练导包错乱的问题

lstrlq 5 месяцев назад
Родитель
Сommit
d22cf19e5c

+ 1 - 1
models/line_detect/line_detect.py

@@ -366,7 +366,7 @@ def linedetect_newresnet18fpn(
 
 
 
-def lineDetect_resnet18_fpn(
+def linedetect_resnet18_fpn(
         *,
         num_classes: Optional[int] = None,
         num_points: Optional[int] = None,

+ 5 - 1
models/line_detect/roi_heads.py

@@ -830,7 +830,7 @@ class RoIHeads(nn.Module):
 
         result: List[Dict[str, torch.Tensor]] = []
         losses = {}
-
+        # _, C, H, W = features['0'].shape  # 忽略 batch_size,因为我们只关心 C, H, W
         if self.training:
             if labels is None:
                 raise ValueError("labels cannot be None")
@@ -862,6 +862,10 @@ class RoIHeads(nn.Module):
             line_proposals = [p["boxes"] for p in result]
             print(f'line_proposals:{len(line_proposals)}')
 
+            # if line_proposals is None or len(line_proposals) == 0:
+            #     # 返回空特征或者跳过该部分计算
+            #     return torch.empty(0, C, H, W).to(features['0'].device)
+
             if self.training:
                 # during training, only focus on positive boxes
                 num_images = len(proposals)

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: \\192.168.50.222\share\rlq\datasets\250612
+  datadir: /data/share/rlq/datasets/250612
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
 

+ 4 - 3
models/line_detect/train_demo.py

@@ -1,8 +1,8 @@
 import torch
 
-from models.line_detect.line_detect import linedetect_newresnet18fpn
-from models.line_net.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, linenet_newresnet50fpn, \
-    get_line_net_convnext_fpn, linenet_newresnet18fpn
+from models.line_detect.line_detect import linedetect_newresnet18fpn, linedetect_resnet50_fpn, linedetect_resnet18_fpn
+
+
 from models.line_net.trainer import Trainer
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -15,6 +15,7 @@ if __name__ == '__main__':
     # model=linenet_newresnet50fpn()
     # model = lineDetect_resnet18_fpn()
 
+    # model=linedetect_resnet18_fpn()
     model=linedetect_newresnet18fpn()
 
     model.start_train(cfg='train.yaml')