فهرست منبع

原分辨率特征图提取线段在4080取得初步效果

lstrlq 5 ماه پیش
والد
کامیت
77a2228b4d
3فایلهای تغییر یافته به همراه5 افزوده شده و 5 حذف شده
  1. 1 1
      models/line_detect/line_detect.py
  2. 2 2
      models/line_detect/loi_heads.py
  3. 2 2
      models/line_detect/train.yaml

+ 1 - 1
models/line_detect/line_detect.py

@@ -168,7 +168,7 @@ class LineDetect(BaseDetectionNet):
 
         if line_head is None:
             keypoint_layers = tuple(1 for _ in range(8))
-            line_head = LineHeads(16, keypoint_layers)
+            line_head = LineHeads(8, keypoint_layers)
 
         if line_predictor is None:
             keypoint_dim_reduced = 512  # == keypoint_layers[-1]

+ 2 - 2
models/line_detect/loi_heads.py

@@ -1046,8 +1046,8 @@ class RoIHeads(nn.Module):
         self.keypoint_predictor = keypoint_predictor
 
         self.channel_compress = nn.Sequential(
-            nn.Conv2d(256, 16, kernel_size=1),
-            nn.BatchNorm2d(16),
+            nn.Conv2d(256, 8, kernel_size=1),
+            nn.BatchNorm2d(8),
             nn.ReLU(inplace=True)
         )
 

+ 2 - 2
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: train_results
-  datadir: \\192.168.50.222/share/zyh/202507/a_dataset
+  datadir: /data/share/zyh/202507/a_dataset
   data_type: rgb
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
@@ -11,7 +11,7 @@ io:
 train_params:
   resume_from:
   num_workers: 8
-  batch_size: 1
+  batch_size: 2
   max_epoch: 80000
   augmentation: True
   optim: