浏览代码

完善保存模型权重和日志功能

RenLiqiang 3 月之前
父节点
当前提交
84bdba3cef
共有 2 个文件被更改,包括 56 次插入、37 次删除
  1. 19 1
      models/line_detect/line_predictor.py
  2. 37 36
      models/line_detect/trainer.py

+ 19 - 1
models/line_detect/line_predictor.py

@@ -26,7 +26,25 @@ def non_maximum_suppression(a):
     mask = (a == ap).float().clamp(min=0.0)
     return a * mask
 
-
class Bottleneck1D(nn.Module):
    """Pre-activation 1D bottleneck residual block.

    The transform branch is BN -> ReLU -> 1x1 conv (reduce channels to
    ``outplanes // 2``), BN -> ReLU -> 3x3 conv (padding=1, so sequence
    length is preserved), BN -> ReLU -> 1x1 conv (expand to ``outplanes``),
    and its output is added to the unmodified input (identity skip).

    NOTE(review): the skip connection is a plain identity add, so the
    input must already have ``outplanes`` channels (i.e. inplanes ==
    outplanes) for ``forward`` to broadcast correctly — confirm callers.
    """

    def __init__(self, inplanes, outplanes):
        super(Bottleneck1D, self).__init__()

        # Bottleneck width: squeeze to half the output channels.
        planes = outplanes // 2
        self.op = nn.Sequential(
            nn.BatchNorm1d(inplanes),
            nn.ReLU(inplace=True),
            nn.Conv1d(inplanes, planes, kernel_size=1),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
            nn.Conv1d(planes, planes, kernel_size=3, padding=1),
            nn.BatchNorm1d(planes),
            nn.ReLU(inplace=True),
            nn.Conv1d(planes, outplanes, kernel_size=1),
        )

    def forward(self, x):
        # Residual add: shapes of x and self.op(x) must match exactly.
        return x + self.op(x)
 
 class LineRCNNPredictor(nn.Module):
     def __init__(self, cfg):

+ 37 - 36
models/line_detect/trainer.py

@@ -1,5 +1,7 @@
+
 import os
 import time
+from datetime import datetime
 
 import torch
 from torch.utils.tensorboard import SummaryWriter
@@ -9,7 +11,7 @@ from models.base.base_trainer import BaseTrainer
 from models.config.config_tool import read_yaml
 from models.line_detect.dataset_LD import WirePointDataset
 from models.line_detect.postprocess import box_line_, show_
-from utils.log_util import show_line, save_latest_model, save_best_model
+from utils.log_util import show_line, save_last_model, save_best_model
 from tools import utils
 
 
@@ -108,25 +110,31 @@ class Trainer(BaseTrainer):
             dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
         )
 
-        # model = linenet_resnet50_fpn().to(self.device)
+        train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
+        wts_path = os.path.join(train_result_ptath, 'weights')
+        tb_path = os.path.join(train_result_ptath, 'logs')
+        writer = SummaryWriter(tb_path)
 
         optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['optim']['lr'])
-        writer = SummaryWriter(kwargs['io']['logdir'])
+        # writer = SummaryWriter(kwargs['io']['logdir'])
         model.to(device)
 
-        # 加载权重
-        save_path = 'logs/pth/best_model.pth'
-        model, optimizer = self.load_best_model(model, optimizer, save_path, device)
 
-        logdir_with_pth = os.path.join(kwargs['io']['logdir'], 'pth')
-        os.makedirs(logdir_with_pth, exist_ok=True)  # 创建目录(如果不存在)
-        latest_model_path = os.path.join(logdir_with_pth, 'latest_model.pth')
-        best_model_path = os.path.join(logdir_with_pth, 'best_model.pth')
+
+        # # 加载权重
+        # save_path = 'logs/pth/best_model.pth'
+        # model, optimizer = self.load_best_model(model, optimizer, save_path, device)
+
+        # logdir_with_pth = os.path.join(kwargs['io']['logdir'], 'pth')
+        # os.makedirs(logdir_with_pth, exist_ok=True)  # 创建目录(如果不存在)
+        last_model_path = os.path.join(wts_path, 'last.pth')
+        best_model_path = os.path.join(wts_path, 'best.pth')
         global_step = 0
 
         for epoch in range(kwargs['optim']['max_epoch']):
             print(f"epoch:{epoch}")
             total_train_loss = 0.0
+
             model.train()
 
             for imgs, targets in data_loader_train:
@@ -135,44 +143,37 @@ class Trainer(BaseTrainer):
                 # print(f'imgs:{len(imgs)}')
                 # print(f'targets:{len(targets)}')
                 losses = model(imgs, targets)
-                # print(losses)
                 loss = _loss(losses)
+                total_train_loss += loss.item()
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
                 self.writer_loss(writer, losses, epoch)
 
+
+            avg_train_loss = total_train_loss / len(data_loader_train)
+            if epoch == 0:
+                best_loss = avg_train_loss;
+
+            writer.add_scalar('loss/train', avg_train_loss, epoch)
+
+
+            if os.path.exists(f'{wts_path}/last.pt'):
+                os.remove(f'{wts_path}/last.pt')
+            # torch.save(model.state_dict(), f'{wts_path}/last.pt')
+            save_last_model(model,last_model_path,epoch,optimizer)
+            best_loss = save_best_model(model,best_model_path,epoch,avg_train_loss,best_loss,optimizer)
+
             model.eval()
             with torch.no_grad():
                 for batch_idx, (imgs, targets) in enumerate(data_loader_val):
                     t_start = time.time()
                     print(f'start to predict:{t_start}')
                     pred = model(self.move_to_device(imgs, self.device))
-                    # t_end = time.time()
-                    # print(f'predict used:{t_end - t_start}')
-                    # t_start=time.time()
-                    # print(f'start to box_line:{t_start}')
-                    # pred_ = box_line_(pred)  # 将box与line对应
-                    # t_end=time.time()
-                    # print(f'box_line_ used:{t_end-t_start}')
-                    # show_(imgs, pred_, epoch, writer)
+                    t_end = time.time()
+                    print(f'predict used:{t_end - t_start}')
                     if batch_idx == 0:
                         show_line(imgs[0], pred, epoch, writer)
                     break
-            avg_train_loss = total_train_loss / len(data_loader_train)
-            writer.add_scalar('loss/train', avg_train_loss, epoch)
-            best_loss = 10000
-            save_latest_model(
-                model,
-                latest_model_path,
-                epoch,
-                optimizer
-            )
-            best_loss = save_best_model(
-                model,
-                best_model_path,
-                epoch,
-                avg_train_loss,
-                best_loss,
-                optimizer
-            )
+
+