
Optimize the predictor to sample features by path interpolation from line_map

RenLiqiang · 8 months ago · commit a5c2d0ce56
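
Per the commit subject, the predictor now builds its line-of-interest features by interpolating sample points along each candidate segment and reading them off the line map, rather than projecting a 128-channel feature stack. A minimal sketch of that sampling idea, assuming a (H, W) line map and (N, 2) float endpoints in (y, x) pixel coordinates; the function name and tensor layout here are illustrative, not this repo's API:

import torch
import torch.nn.functional as F

def sample_line_features(lmap, p0, p1, n_pts=32):
    # lmap: (H, W) sigmoid-activated line map; p0, p1: (N, 2) endpoints in (y, x) pixels
    H, W = lmap.shape
    t = torch.linspace(0, 1, n_pts, device=lmap.device).view(1, n_pts, 1)
    pts = p0.unsqueeze(1) * (1 - t) + p1.unsqueeze(1) * t        # (N, n_pts, 2) points on each segment
    # grid_sample expects (x, y) coordinates normalized to [-1, 1]
    gx = pts[..., 1] / (W - 1) * 2 - 1
    gy = pts[..., 0] / (H - 1) * 2 - 1
    grid = torch.stack([gx, gy], dim=-1).unsqueeze(0)            # (1, N, n_pts, 2)
    feat = F.grid_sample(lmap.view(1, 1, H, W), grid, align_corners=True)
    return feat.view(-1, n_pts)                                  # (N, n_pts): one scalar per sampled point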

+ 11 - 6
models/line_detect/line_predictor.py

@@ -50,7 +50,7 @@ class LineRCNNPredictor(nn.Module):
     def __init__(self,n_pts0 = 32,
                  n_pts1 = 8,
                  n_stc_posl =300,
-                 dim_loi = 128,
+                 dim_loi = 1,
                  use_conv = 0,
                  dim_fc = 1024,
                  n_out_line = 2500,
@@ -190,11 +190,13 @@ class LineRCNNPredictor(nn.Module):
         n_jtyp = T["junc_map"].shape[1]
         offset = self.head_off
         result = {}
+        print(f' wires_targets len:{len(wires_targets)}')
         for stack, output in enumerate([inputs]):
             output = output.transpose(0, 1).reshape([-1, batch, row, col]).contiguous()
             # print(f"Stack {stack} output shape: {output.shape}")  # 打印每层的输出形状
             jmap = output[0: offset[0]].reshape(n_jtyp, 2, batch, row, col)
-            lmap = output[offset[0]: offset[1]].squeeze(0)
+            # lmap = output[offset[0]: offset[1]].squeeze(0)
+            lmap = output[offset[0]: offset[1]]
             joff = output[offset[1]: offset[2]].reshape(n_jtyp, 2, batch, row, col)
 
             if stack == 0:
@@ -208,12 +210,15 @@ class LineRCNNPredictor(nn.Module):
                 # visualize_feature_map(joff[0, 0], title=f"joff - Stack {stack}")
 
         h = result["preds"]
-        # print(f'features shape:{features.shape}')
-        x = self.fc1(features)
-
-        # print(f'x:{x.shape}')
+        print(f'features shape:{features.shape}')
+        print(f'inputs shape :{inputs.shape}')
+        # x = self.fc1(features)
+        x = inputs[:,2:3,:,:].sigmoid()
+        print(f'x:{x.shape}')
 
         n_batch, n_channel, row, col = x.shape
+        # n_batch, n_channel, row, col = x.shape
+
 
         # print(f'n_batch:{n_batch}, n_channel:{n_channel}, row:{row}, col:{col}')
 

+ 239 - 0
models/line_detect/test_train2.py

@@ -0,0 +1,239 @@
+import os
+import time
+from datetime import datetime
+import torch
+from torch.utils.tensorboard import SummaryWriter
+from models.base.base_model import BaseModel
+from models.base.base_trainer import BaseTrainer
+from models.config.config_tool import read_yaml
+from models.line_detect.dataset_LD import WirePointDataset
+from models.line_detect.postprocess import box_line_, show_
+from utils.log_util import show_line, save_last_model, save_best_model
+from tools import utils
+
+
+def _loss(losses):
+    total_loss = 0
+    for i in losses.keys():
+        if i != "loss_wirepoint":
+            total_loss += losses[i]
+        else:
+            loss_labels = losses[i]["losses"]
+    loss_labels_k = list(loss_labels[0].keys())
+    for j, name in enumerate(loss_labels_k):
+        loss = loss_labels[0][name].mean()
+        total_loss += loss
+    return total_loss
+
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def move_to_device(data, device):
+    if isinstance(data, (list, tuple)):
+        return type(data)(move_to_device(item, device) for item in data)
+    elif isinstance(data, dict):
+        return {key: move_to_device(value, device) for key, value in data.items()}
+    elif isinstance(data, torch.Tensor):
+        return data.to(device)
+    else:
+        return data  # leave non-tensor data unchanged
+
+
+class Trainer(BaseTrainer):
+    def __init__(self, model=None,
+                 dataset=None,
+                 device='cuda',
+                 freeze_config=None,  # new: parameter-freezing config
+                 **kwargs):
+        super().__init__(model, dataset, device, **kwargs)
+        self.freeze_config = freeze_config or {}  # default: no modules frozen
+    def move_to_device(self, data, device):
+        if isinstance(data, (list, tuple)):
+            return type(data)(self.move_to_device(item, device) for item in data)
+        elif isinstance(data, dict):
+            return {key: self.move_to_device(value, device) for key, value in data.items()}
+        elif isinstance(data, torch.Tensor):
+            return data.to(device)
+        else:
+            return data  # leave non-tensor data unchanged
+    def freeze_params(self, model):
+        """根据配置冻结模型参数"""
+        default_config = {
+            'backbone': True,  # freeze the backbone
+            'rpn': False,      # keep the rpn trainable
+            'roi_heads': {
+                'box_head': False,
+                'box_predictor': False,
+                'line_head': False,
+                'line_predictor': {
+                    'fc1': False,
+                    'fc2': {
+                        '0': False,
+                        '2': False,
+                        '4': False
+                    }
+                }
+            }
+        }
+
+        # merge the user-supplied config into the defaults (shallow update)
+        default_config.update(self.freeze_config)
+        config = default_config
+
+        print("\n===== Parameter Freezing Configuration =====")
+        for name, module in model.named_children():
+            if name in config:
+                if isinstance(config[name], bool):
+                    for param in module.parameters():
+                        param.requires_grad = not config[name]
+                    print(f"{'Frozen' if config[name] else 'Trainable'} module: {name}")
+
+                elif isinstance(config[name], dict):
+                    for subname, submodule in module.named_children():
+                        if subname in config[name]:
+                            if isinstance(config[name][subname], bool):
+                                for param in submodule.parameters():
+                                    param.requires_grad = not config[name][subname]
+                                print(
+                                    f"{'Frozen' if config[name][subname] else 'Trainable'} submodule: {name}.{subname}")
+
+                            elif isinstance(config[name][subname], dict):
+                                for subsubname, subsubmodule in submodule.named_children():
+                                    if subsubname in config[name][subname]:
+                                        for param in subsubmodule.parameters():
+                                            param.requires_grad = not config[name][subname][subsubname]
+                                        print(
+                                            f"{'Frozen' if config[name][subname][subsubname] else 'Trainable'} sub-submodule: {name}.{subname}.{subsubname}")
+
+        # print parameter statistics
+        total_params = sum(p.numel() for p in model.parameters())
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print(f"\nTotal Parameters: {total_params:,}")
+        print(f"Trainable Parameters: {trainable_params:,}")
+        print(f"Frozen Parameters: {total_params - trainable_params:,}")
+
+    def load_best_model(self, model, optimizer, save_path, device):
+        if os.path.exists(save_path):
+            checkpoint = torch.load(save_path, map_location=device)
+            model.load_state_dict(checkpoint['model_state_dict'])
+            if optimizer is not None:
+                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            epoch = checkpoint['epoch']
+            loss = checkpoint['loss']
+            print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
+        else:
+            print(f"No saved model found at {save_path}")
+        return model, optimizer
+
+    def writer_loss(self, writer, losses, epoch):
+        try:
+            for key, value in losses.items():
+                if key == 'loss_wirepoint':
+                    for subdict in losses['loss_wirepoint']['losses']:
+                        for subkey, subvalue in subdict.items():
+                            writer.add_scalar(f'loss/{subkey}',
+                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
+                                              epoch)
+                elif isinstance(value, torch.Tensor):
+                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
+        except Exception as e:
+            print(f"TensorBoard logging error: {e}")
+
+    def train_cfg(self, model: BaseModel, cfg, freeze_config=None):  # new: accepts a freeze config
+        cfg = read_yaml(cfg)
+        self.freeze_config = freeze_config or {}  # update the freeze config
+        self.train(model, **cfg)
+
+    def train(self, model, **kwargs):
+        dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
+        train_sampler = torch.utils.data.RandomSampler(dataset_train)
+        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=4, drop_last=True)
+        train_collate_fn = utils.collate_fn_wirepoint
+        data_loader_train = torch.utils.data.DataLoader(
+            dataset_train, batch_sampler=train_batch_sampler, num_workers=1, collate_fn=train_collate_fn
+        )
+
+        dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
+        val_sampler = torch.utils.data.RandomSampler(dataset_val)
+        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=4, drop_last=True)
+        val_collate_fn = utils.collate_fn_wirepoint
+        data_loader_val = torch.utils.data.DataLoader(
+            dataset_val, batch_sampler=val_batch_sampler, num_workers=1, collate_fn=val_collate_fn
+        )
+
+        train_result_path = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
+        wts_path = os.path.join(train_result_path, 'weights')
+        tb_path = os.path.join(train_result_path, 'logs')
+        writer = SummaryWriter(tb_path)
+        model.to(device)
+        # # load pretrained weights
+        # save_path =r"F:\BaiduNetdiskDownload\r50fpn_wts_e350\best.pth"
+        # model, _ = self.load_best_model(model, None, save_path, device)
+        # freeze parameters
+        # self.freeze_params(model)
+
+        # initialize the optimizer (train only the unfrozen parameters)
+        optimizer = torch.optim.Adam(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=kwargs['optim']['lr']
+        )
+
+        last_model_path = os.path.join(wts_path, 'last.pth')
+        best_model_path = os.path.join(wts_path, 'best.pth')
+        global_step = 0
+
+        for epoch in range(kwargs['optim']['max_epoch']):
+            print(f"epoch:{epoch}")
+            total_train_loss = 0.0
+            model.train()
+            for imgs, targets in data_loader_train:
+                imgs = move_to_device(imgs, device)
+                targets = move_to_device(targets, device)
+                losses = model(imgs, targets)
+                loss = _loss(losses)
+                total_train_loss += loss.item()
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+                self.writer_loss(writer, losses, global_step)
+                global_step += 1
+
+            avg_train_loss = total_train_loss / len(data_loader_train)
+            if epoch == 0:
+                best_loss = avg_train_loss
+            writer.add_scalar('loss/train', avg_train_loss, epoch)
+
+            if os.path.exists(last_model_path):
+                os.remove(last_model_path)
+            save_last_model(model, last_model_path, epoch, optimizer)
+            best_loss = save_best_model(model, best_model_path, epoch, avg_train_loss, best_loss, optimizer)
+
+            model.eval()
+            with torch.no_grad():
+                for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+                    t_start = time.time()
+                    print(f'start to predict:{t_start}')
+                    pred = model(self.move_to_device(imgs, self.device))
+                    # print(f'pred:{pred}')
+                    t_end = time.time()
+                    print(f'predict used:{t_end - t_start}')
+                    if batch_idx == 0:
+                        show_line(imgs[0], pred, epoch, writer)
+                    break
+
+
+from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn
+
+if __name__ == '__main__':
+
+    # model = LineNet('line_net.yaml')
+    model=linenet_resnet50_fpn()
+    #model=linenet_resnet18_fpn()
+    # trainer = Trainer()
+    # trainer.train_cfg(model,cfg='./train.yaml')
+    # model.train_by_cfg(cfg='train.yaml')
+    trainer = Trainer()
+    trainer.train_cfg(model=model, cfg='train.yaml')
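
freeze_params above walks the model with a nested dict that mirrors the module tree (True means frozen), shallow-merged over default_config; note that the self.freeze_params(model) call is still commented out in train(), so the config has no effect until that line is re-enabled. A hedged usage sketch with illustrative key choices:

# hypothetical config: freeze the backbone and line_predictor.fc1, train the rest
freeze_config = {
    'backbone': True,
    'roi_heads': {'line_predictor': {'fc1': True}},
}
trainer = Trainer(freeze_config=freeze_config)
trainer.train_cfg(model=model, cfg='train.yaml', freeze_config=freeze_config)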

+ 1 - 1
models/line_detect/train.yaml

@@ -1,6 +1,6 @@
 io:
   logdir: logs/
-  datadir: /root/autodl-tmp/wirenet_rgb_gray
+  datadir: I:/datasets/0322_suanzaisheng
 #  datadir: I:\datasets\wirenet_1000
   resume_from:
   num_workers: 8

+ 2 - 2
models/line_detect/trainer.py

@@ -95,7 +95,7 @@ class Trainer(BaseTrainer):
         dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
         train_sampler = torch.utils.data.RandomSampler(dataset_train)
         # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=64, drop_last=True)
+        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
         train_collate_fn = utils.collate_fn_wirepoint
         data_loader_train = torch.utils.data.DataLoader(
             dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
@@ -104,7 +104,7 @@ class Trainer(BaseTrainer):
         dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
         val_sampler = torch.utils.data.RandomSampler(dataset_val)
         # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=64, drop_last=True)
+        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
         val_collate_fn = utils.collate_fn_wirepoint
         data_loader_val = torch.utils.data.DataLoader(
             dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn

+ 1 - 1
utils/log_util.py

@@ -108,7 +108,7 @@ def show_line(img, pred, epoch, writer):
     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
     nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
 
-    for i, t in enumerate([0.85]):
+    for i, t in enumerate([0.001]):
         plt.gca().set_axis_off()
         plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
         plt.margins(0, 0)