Parcourir la source

添加学习率调度器

lstrlq il y a 5 mois
Parent
commit
cd0ab8613e

+ 4 - 3
models/line_detect/line_dataset.py

@@ -53,10 +53,11 @@ class LineDataset(BaseDataset):
         if self.data_type == 'tiff':
             lbl_path = os.path.join(self.lbl_path, self.imgs[index][:-4] + 'json')
             # img = imageio.v3.imread(img_path).reshape(512, 512, 1)
-            img = imageio.v3.imread(img_path)[:, :, -1].reshape(512, 512, 1)
-            img_3channel = np.zeros((512, 512, 3), dtype=img.dtype)
-            img_3channel[:, :, 2] = img[:, :, 0]
+            img = imageio.v3.imread(img_path)[:, :, :3]
+            # img_3channel = np.zeros((512, 512, 3), dtype=img.dtype)
+            # img_3channel[:, :, 2] = img[:, :, 0]
 
+            img_3channel=img
             w, h = img.shape[:2]
             img = torch.from_numpy(img_3channel).permute(2, 0, 1)
         else:

+ 2 - 2
models/line_detect/train.yaml

@@ -1,7 +1,7 @@
 io:
   logdir: train_results
-  datadir: /data/share/rlq/datasets/250705
-  data_type: tiff
+  datadir: /data/share/zyh/202507/a_dataset
+  data_type: jpg
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
 

+ 1 - 0
models/line_detect/train_demo.py

@@ -17,5 +17,6 @@ if __name__ == '__main__':
 
     # model=linedetect_resnet18_fpn()
     model=linedetect_newresnet18fpn()
+    model.load_weights(r'/home/admin/projects/MultiVisionModels/models/line_detect/train_results/20250706_150832/weights/best_val.pth')
 
     model.start_train(cfg='train.yaml')

+ 12 - 2
models/line_detect/trainer.py

@@ -5,6 +5,7 @@ from datetime import datetime
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
+from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
 from torch.utils.tensorboard import SummaryWriter
 
 from libs.vision_libs.utils import draw_bounding_boxes, draw_keypoints
@@ -263,17 +264,23 @@ class Trainer(BaseTrainer):
 
         optimizer = torch.optim.Adam(
             filter(lambda p: p.requires_grad, model.parameters()),
-            lr=kwargs['train_params']['optim']['lr']
+            lr=kwargs['train_params']['optim']['lr'],
+            weight_decay=kwargs['train_params']['optim']['weight_decay'],
+
         )
+        # scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
+        scheduler = ReduceLROnPlateau(optimizer, 'min', patience=30)
 
         for epoch in range(self.max_epoch):
             print(f"train epoch:{epoch}")
 
-            model, epoch_train_loss = self.one_epoch(model, data_loader_train, epoch, optimizer)
 
+            model, epoch_train_loss = self.one_epoch(model, data_loader_train, epoch, optimizer)
+            scheduler.step(epoch_train_loss)
             # ========== Validation ==========
             with torch.no_grad():
                 model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val')
+                scheduler.step(epoch_val_loss)
 
             if epoch==0:
                 best_train_loss = epoch_train_loss
@@ -286,6 +293,9 @@ class Trainer(BaseTrainer):
             best_val_loss = self.save_best_model(model, self.best_val_model_path, epoch, epoch_val_loss, best_val_loss,
                                                  optimizer)
 
+
+
+
     def one_epoch(self, model, data_loader, epoch, optimizer, phase='train'):
         if phase == 'train':
             model.train()