Browse Source

修改线性层参数

lstrlq 5 months ago
parent
commit
228ee2a341
2 changed files with 4 additions and 4 deletions
  1. 3 3
      models/line_detect/line_predictor.py
  2. 1 1
      models/line_detect/train_demo.py

+ 3 - 3
models/line_detect/line_predictor.py

@@ -106,11 +106,11 @@ class LineRCNNPredictor(nn.Module):
         else:
             self.pooling = nn.MaxPool1d(scale_factor, scale_factor)
             self.fc2 = nn.Sequential(
-                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc),
+                nn.Linear(self.dim_loi * self.n_pts1 + FEATURE_DIM, self.dim_fc*16),
                 nn.ReLU(inplace=True),
-                nn.Linear(self.dim_fc, self.dim_fc),
+                nn.Linear(self.dim_fc*16, self.dim_fc*8),
                 nn.ReLU(inplace=True),
-                nn.Linear(self.dim_fc, 1),
+                nn.Linear(self.dim_fc*8, 1),
             )
         self.loss = nn.BCEWithLogitsLoss(reduction="none")
 

+ 1 - 1
models/line_detect/train_demo.py

@@ -12,7 +12,7 @@ if __name__ == '__main__':
     # model = linenet_resnet18_fpn()
     # model=get_line_net_convnext_fpn(num_classes=2).to(device)
     model=linenet_newresnet50fpn()
-    model.load_best_model('train_results/20250622_140412/weights/best_val.pth')
+    # model.load_best_model('train_results/20250622_143530/weights/best_val.pth')
     # trainer = Trainer()
     # trainer.train_cfg(model,cfg='./train.yaml')
     model.start_train(cfg='train.yaml')