RenLiqiang 7 months ago
Origin
Commit
28658aaa06

+ 4 - 5
models/base/backbone_factory.py

@@ -2,20 +2,19 @@ from collections import OrderedDict
 
 from libs.vision_libs import models
 from libs.vision_libs.models import mobilenet_v3_large, EfficientNet_V2_S_Weights, efficientnet_v2_s, \
-    EfficientNet_V2_M_Weights, efficientnet_v2_m, EfficientNet_V2_L_Weights, efficientnet_v2_l
+    EfficientNet_V2_M_Weights, efficientnet_v2_m, EfficientNet_V2_L_Weights, efficientnet_v2_l, ConvNeXt_Base_Weights
 from libs.vision_libs.models._utils import _ovewrite_value_param, handle_legacy_interface
 from libs.vision_libs.models.detection import FasterRCNN
 from libs.vision_libs.models.detection.anchor_utils import AnchorGenerator
 from libs.vision_libs.models.detection.ssdlite import _mobilenet_extractor
 from libs.vision_libs.models.resnet import resnet50, ResNet50_Weights, resnet18
-from libs.vision_libs.models.detection._utils import overwrite_eps
+
 from libs.vision_libs.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
 from libs.vision_libs.ops import misc as misc_nn_ops, MultiScaleRoIAlign
 from torch import nn
 
 import torch
-from torchvision.models.detection.backbone_utils import BackboneWithFPN, resnet_fpn_backbone
-from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
+from libs.vision_libs.models.detection.backbone_utils import BackboneWithFPN
 
 
 def get_resnet50_fpn():
@@ -45,7 +44,7 @@ def get_mobilenet_v3_large_fpn():
     backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
     return backbone
 def get_convnext_fpn():
-    convnext = models.convnext_base(pretrained=True)
+    convnext = models.convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT)
     # convnext = models.convnext_small(pretrained=True)
     # convnext = models.convnext_large(pretrained=True)
 

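Side note on the backbone change above: pretrained=True is the deprecated torchvision-style flag, and the commit moves to the weights-enum API. A minimal sketch of the two variants, assuming libs.vision_libs mirrors the torchvision model interface:

    from libs.vision_libs import models
    from libs.vision_libs.models import ConvNeXt_Base_Weights

    # pretrained ConvNeXt-Base, the replacement for the old pretrained=True
    convnext = models.convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT)

    # randomly initialized weights, the replacement for pretrained=False
    convnext_scratch = models.convnext_base(weights=None)
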
+ 1 - 1
models/base/base_model.py

@@ -14,7 +14,7 @@ class BaseModel(ABC, torch.nn.Module):
         self.trainer = None
 
     @abstractmethod
-    def train_by_cfg(self, cfg):
+    def start_train(self, cfg):
         return
 
     # @abstractmethod
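Side note on the renamed abstract method: LineNet implements it further down in this commit. A hypothetical minimal subclass would look like the sketch below (the class name is illustrative, and it assumes start_train is the only remaining abstract method on BaseModel):

    from models.base.base_model import BaseModel
    from models.line_detect.trainer import Trainer

    class DemoDetector(BaseModel):  # illustrative subclass, not part of the repo
        def start_train(self, cfg):
            # delegate to the trainer, mirroring LineNet.start_train below
            self.trainer = Trainer()
            self.trainer.train_from_cfg(model=self, cfg=cfg)
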

+ 1 - 1
models/line_detect/.ipynb_checkpoints/line_net-checkpoint.py

@@ -207,7 +207,7 @@ class LineNet(BaseDetectionNet):
     def train_by_cfg(self, cfg):
         # cfg = read_yaml(cfg)
         self.trainer = Trainer()
-        self.trainer.train_cfg(model=self,cfg=cfg)
+        self.trainer.train_from_cfg(model=self, cfg=cfg)
 
 
 

+ 0 - 286
models/line_detect/111.py

@@ -1,286 +0,0 @@
-import os
-import time
-from datetime import datetime
-import torch
-from torch.utils.tensorboard import SummaryWriter
-from models.base.base_model import BaseModel
-from models.base.base_trainer import BaseTrainer
-from models.config.config_tool import read_yaml
-from models.line_detect.dataset_LD import WirePointDataset
-from models.line_detect.postprocess import box_line_, show_
-from utils.log_util import show_line, save_last_model, save_best_model
-from tools import utils
-
-
-def _loss(losses):
-    total_loss = 0
-    for i in losses.keys():
-        if i != "loss_wirepoint":
-            total_loss += losses[i]
-        else:
-            loss_labels = losses[i]["losses"]
-    loss_labels_k = list(loss_labels[0].keys())
-    for j, name in enumerate(loss_labels_k):
-        loss = loss_labels[0][name].mean()
-        total_loss += loss
-    return total_loss
-
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-
-def move_to_device(data, device):
-    if isinstance(data, (list, tuple)):
-        return type(data)(move_to_device(item, device) for item in data)
-    elif isinstance(data, dict):
-        return {key: move_to_device(value, device) for key, value in data.items()}
-    elif isinstance(data, torch.Tensor):
-        return data.to(device)
-    else:
-        return data  # leave non-tensor data unchanged
-
-
-class Trainer(BaseTrainer):
-    def __init__(self, model=None,
-                 dataset=None,
-                 device='cuda',
-                 freeze_config=None,  # new: freeze-parameter configuration
-                 **kwargs):
-        super().__init__(model, dataset, device, **kwargs)
-        self.freeze_config = freeze_config or {}  # freeze config defaults to empty
-
-    def move_to_device(self, data, device):
-        if isinstance(data, (list, tuple)):
-            return type(data)(self.move_to_device(item, device) for item in data)
-        elif isinstance(data, dict):
-            return {key: self.move_to_device(value, device) for key, value in data.items()}
-        elif isinstance(data, torch.Tensor):
-            return data.to(device)
-        else:
-            return data  # leave non-tensor data unchanged
-
-    def freeze_params(self, model):
-        """Freeze model parameters according to the configuration."""
-        default_config = {
-            'backbone': True,  # freeze the backbone
-            'rpn': False,  # keep the rpn trainable
-            'roi_heads': {
-                'box_head': False,
-                'box_predictor': False,
-                'line_head': False,
-                'line_predictor': {
-                    'fc1': False,
-                    'fc2': {
-                        '0': False,
-                        '2': False,
-                        '4': False
-                    }
-                }
-            }
-        }
-
-        # update the default configuration
-        default_config.update(self.freeze_config)
-        config = default_config
-
-        print("\n===== Parameter Freezing Configuration =====")
-        for name, module in model.named_children():
-            if name in config:
-                if isinstance(config[name], bool):
-                    for param in module.parameters():
-                        param.requires_grad = not config[name]
-                    print(f"{'Frozen' if config[name] else 'Trainable'} module: {name}")
-
-                elif isinstance(config[name], dict):
-                    for subname, submodule in module.named_children():
-                        if subname in config[name]:
-                            if isinstance(config[name][subname], bool):
-                                for param in submodule.parameters():
-                                    param.requires_grad = not config[name][subname]
-                                print(
-                                    f"{'Frozen' if config[name][subname] else 'Trainable'} submodule: {name}.{subname}")
-
-                            elif isinstance(config[name][subname], dict):
-                                for subsubname, subsubmodule in submodule.named_children():
-                                    if subsubname in config[name][subname]:
-                                        for param in subsubmodule.parameters():
-                                            param.requires_grad = not config[name][subname][subsubname]
-                                        print(
-                                            f"{'Frozen' if config[name][subname][subsubname] else 'Trainable'} sub-submodule: {name}.{subname}.{subsubname}")
-
-        # print parameter statistics
-        total_params = sum(p.numel() for p in model.parameters())
-        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-        print(f"\nTotal Parameters: {total_params:,}")
-        print(f"Trainable Parameters: {trainable_params:,}")
-        print(f"Frozen Parameters: {total_params - trainable_params:,}")
-
-    def load_best_model(self, model, optimizer, save_path, device):
-        if os.path.exists(save_path):
-            checkpoint = torch.load(save_path, map_location=device)
-            model.load_state_dict(checkpoint['model_state_dict'])
-            if optimizer is not None:
-                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-            epoch = checkpoint['epoch']
-            loss = checkpoint['loss']
-            print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
-        else:
-            print(f"No saved model found at {save_path}")
-        return model, optimizer
-
-    def writer_loss(self, writer, losses, epoch,mode='train'):
-        try:
-            for key, value in losses.items():
-                if key == 'loss_wirepoint':
-                    for subdict in losses['loss_wirepoint']['losses']:
-                        for subkey, subvalue in subdict.items():
-                            writer.add_scalar(f'{mode}/loss/{subkey}',
-                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
-                                              epoch)
-                elif isinstance(value, torch.Tensor):
-                    writer.add_scalar(f'{mode}/loss/{key}', value.item(), epoch)
-        except Exception as e:
-            print(f"TensorBoard logging error: {e}")
-
-    def train_cfg(self, model: BaseModel, cfg, freeze_config=None):  # new: accepts a freeze configuration
-        cfg = read_yaml(cfg)
-        self.freeze_config = freeze_config or {}  # update the freeze configuration
-        self.train(model, **cfg)
-
-    def train(self, model, **kwargs):
-        dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
-        train_sampler = torch.utils.data.RandomSampler(dataset_train)
-        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
-        train_collate_fn = utils.collate_fn_wirepoint
-        data_loader_train = torch.utils.data.DataLoader(
-            dataset_train, batch_sampler=train_batch_sampler, num_workers=1, collate_fn=train_collate_fn
-        )
-
-        dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
-        val_sampler = torch.utils.data.RandomSampler(dataset_val)
-        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
-        val_collate_fn = utils.collate_fn_wirepoint
-        data_loader_val = torch.utils.data.DataLoader(
-            dataset_val, batch_sampler=val_batch_sampler, num_workers=1, collate_fn=val_collate_fn
-        )
-
-        train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
-        wts_path = os.path.join(train_result_ptath, 'weights')
-        tb_path = os.path.join(train_result_ptath, 'logs')
-        writer = SummaryWriter(tb_path)
-        model.to(device)
-        # # load weights
-        # save_path =r"F:\BaiduNetdiskDownload\r50fpn_wts_e350\best.pth"
-        # model, _ = self.load_best_model(model, None, save_path, device)
-        # freeze parameters
-        # self.freeze_params(model)
-
-        # initialize the optimizer (train only the unfrozen parameters)
-        optimizer = torch.optim.Adam(
-            filter(lambda p: p.requires_grad, model.parameters()),
-            lr=kwargs['optim']['lr']
-        )
-
-        last_model_path = os.path.join(wts_path, 'last.pth')
-        best_train_model_path = os.path.join(wts_path, 'best_train.pth')
-        best_val_model_path = os.path.join(wts_path, 'best_val.pth')
-        global_train_step = 0
-        global_val_step = 0
-
-        for epoch in range(kwargs['optim']['max_epoch']):
-            print(f"epoch:{epoch}")
-            total_train_loss = 0.0
-            model.train()
-            for imgs, targets in data_loader_train:
-                imgs = move_to_device(imgs, device)
-                targets = move_to_device(targets, device)
-                losses = model(imgs, targets)
-                loss = _loss(losses)
-                total_train_loss += loss.item()
-                optimizer.zero_grad()
-                loss.backward()
-                optimizer.step()
-                self.writer_loss(writer, losses, global_train_step)
-                global_train_step += 1
-
-
-
-            model.eval()
-            print(f'model.eval!!')
-            # ========== Validation ==========
-            total_val_loss = 0.0
-            batch_idx=0
-            with torch.no_grad():
-                for imgs, targets in data_loader_val:
-                    t_start = time.time()
-                    print(f'start to predict:{t_start}')
-
-                    imgs = move_to_device(imgs, device)
-                    targets = move_to_device(targets, device)
-                    print(f'targets:{targets}')
-
-                    _,losses = model(imgs, targets)
-                    self.writer_loss(writer, losses, global_val_step,mode='val')
-                    global_val_step+=1
-                    print(f'val losses:{losses}')
-                    loss = _loss(losses)
-                    total_val_loss += loss.item()
-
-                    pred= model(self.move_to_device(imgs, self.device))
-
-                    # print(f'pred:{pred}')
-                    t_end = time.time()
-                    print(f'predict used:{t_end - t_start}')
-                    if batch_idx == 0:
-                        show_line(imgs[0], pred, epoch, writer)
-                        batch_idx+=1
-
-
-            avg_val_loss = total_val_loss / len(data_loader_val)
-            # print(f'avg_val_loss:{avg_val_loss}')
-
-            avg_train_loss = total_train_loss / len(data_loader_train)
-            print(f'avg_train_loss:{avg_train_loss}')
-            if epoch == 0:
-                best_train_loss = avg_train_loss
-                best_val_loss = avg_val_loss
-            writer.add_scalar('loss/train', avg_train_loss, epoch)
-
-            if os.path.exists(f'{wts_path}/last.pt'):
-                os.remove(f'{wts_path}/last.pt')
-            save_last_model(model, last_model_path, epoch, optimizer)
-            best_train_loss = save_best_model(model, best_train_model_path, epoch, avg_train_loss, best_train_loss,
-                                              optimizer)
-
-            best_val_loss = save_best_model(model, best_val_model_path, epoch, avg_val_loss, best_val_loss,
-                                              optimizer)
-            writer.add_scalar('loss/val', avg_val_loss, epoch)
-            print(f"Epoch {epoch} - Val Loss: {avg_val_loss:.4f}")
-
-
-import torch
-
-from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, \
-    get_line_net_efficientnetv2, get_line_net_convnext_fpn
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-if __name__ == '__main__':
-    # model = LineNet('line_net.yaml')
-    # model = linenet_resnet50_fpn().to(device)
-    model=linenet_resnet18_fpn().to(device)
-    # model=get_line_net_efficientnetv2(2, pretrained_backbone=True).to(device)
-    # model=get_line_net_convnext_fpn(num_classes=2).to(device)
-    # model=linenet_resnet18_fpn()
-    trainer = Trainer()
-    trainer.train_cfg(model,cfg='./train.yaml')
-    model.train_by_cfg(cfg='train.yaml')
-    # trainer = Trainer()
-    # trainer.train_cfg(model=model, cfg='train.yaml')
-    #
-    # pt_path = r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth"
-    # img_path = r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png"
-    # model.predict(pt_path, model, img_path, type=1, threshold=0, save_path=None, show=True)
-
-    # model = model.load_best_model(model, r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth")
-    # img_path = r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png"
-    # model.predict1(model, img_path, type=1, threshold=0, save_path=None, show=True)

+ 2 - 2
models/line_detect/line_net.py

@@ -214,10 +214,10 @@ class LineNet(BaseDetectionNet):
         # self.roi_heads.line_head = line_head
         # self.roi_heads.line_predictor = line_predictor
 
-    def train_by_cfg(self, cfg):
+    def start_train(self, cfg):
         # cfg = read_yaml(cfg)
         self.trainer = Trainer()
-        self.trainer.train_cfg(model=self, cfg=cfg)
+        self.trainer.train_from_cfg(model=self, cfg=cfg)
 
     def load_best_model(self,model,  save_path, device='cuda'):
         if os.path.exists(save_path):
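With this rename and the matching rename in trainer.py, there are two equivalent ways to start training, as train_demo.py below also shows; a short sketch:

    from models.line_detect.line_net import linenet_resnet18_fpn
    from models.line_detect.trainer import Trainer

    model = linenet_resnet18_fpn()

    # via the model wrapper
    model.start_train(cfg='train.yaml')

    # or by driving the trainer directly
    trainer = Trainer()
    trainer.train_from_cfg(model=model, cfg='train.yaml')
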

+ 1 - 1
models/line_detect/roi_heads.py

@@ -4,7 +4,7 @@ import torch
 import torch.nn.functional as F
 import torchvision
 from torch import nn, Tensor
-from torchvision.ops import boxes as box_ops, roi_align
+from libs.vision_libs.ops import boxes as box_ops, roi_align
 
 import libs.vision_libs.models.detection._utils as det_utils
 

+ 1 - 1
models/line_detect/test_tiff.py

@@ -25,7 +25,7 @@ def pointscloud2depthmap(points):
 
             # 检查是否在图像边界内
             if 0 <= u < width and 0 <= v < height:
-                point_image[v, u, :] = point
+                point_image[v, u, :] = (X, Y, Z)
 
     return point_image
 def pointscloud2colorimg(points):
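The changed line writes the explicit (X, Y, Z) components instead of the raw point record. The projection loop itself is outside this hunk; a hypothetical pinhole-projection version is sketched below, where the intrinsics fx, fy, cx, cy and the image size are assumptions for illustration, not values from test_tiff.py:

    import numpy as np

    def pointscloud2depthmap_sketch(points, width=640, height=480,
                                    fx=600.0, fy=600.0, cx=320.0, cy=240.0):
        # hypothetical re-implementation; the real function in test_tiff.py may differ
        point_image = np.zeros((height, width, 3), dtype=np.float32)
        for X, Y, Z in points:
            if Z <= 0:  # skip points behind the camera
                continue
            u = int(round(fx * X / Z + cx))  # image column
            v = int(round(fy * Y / Z + cy))  # image row
            if 0 <= u < width and 0 <= v < height:
                point_image[v, u, :] = (X, Y, Z)
        return point_image
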

+ 31 - 10
models/line_detect/train.yaml

@@ -1,17 +1,38 @@
 io:
-  logdir: logs/
+  logdir: train_results
   datadir: \\192.168.50.222/share/zyh/513train/a_dataset
 #  datadir: D:\python\PycharmProjects\data_20250223\0423_
 #  datadir: I:\datasets\wirenet_1000
-  resume_from:
-  num_workers: 8
+
   tensorboard_port: 6000
   validation_interval: 300
 
-optim:
-  name: Adam
-  lr: 4.0e-4
-  amsgrad: True
-  weight_decay: 1.0e-4
-  max_epoch: 90000000
-  lr_decay_epoch: 10
+train_params:
+  resume_from:
+  num_workers: 8
+  batch_size: 4
+  max_epoch: 80000
+  optim:
+    name: Adam
+    lr: 4.0e-4
+    amsgrad: True
+    weight_decay: 1.0e-4
+    lr_decay_epoch: 10
+
+# parameter freezing
+  freeze_params:
+    backbone: False
+    rpn: False
+    roi_heads:
+      box_head: False
+      box_predictor: False
+      line_head: False
+      line_predictor:
+        fc1: False
+        fc2:
+          '0': False
+          '2': False
+          '4': False
+
+
+
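The keys consumed by the new Trainer.init_params (see trainer.py below) now sit under train_params. A minimal sketch of reading the restructured file, assuming read_yaml returns a plain nested dict:

    from models.config.config_tool import read_yaml

    cfg = read_yaml('train.yaml')  # path relative to models/line_detect
    train_params = cfg['train_params']
    batch_size = train_params['batch_size']        # 4
    lr = train_params['optim']['lr']               # 4.0e-4
    freeze_params = train_params['freeze_params']  # nested module-name -> bool mapping
    datadir = cfg['io']['datadir']
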

+ 5 - 4
models/line_detect/test_train.py → models/line_detect/train_demo.py

@@ -1,14 +1,15 @@
 import torch
 
-from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn
+from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, get_line_net_convnext_fpn
 from models.line_detect.trainer import Trainer
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if __name__ == '__main__':
 
     # model = LineNet('line_net.yaml')
-    model=linenet_resnet50_fpn()
-    #model=linenet_resnet18_fpn()
+    # model=linenet_resnet50_fpn()
+    # model=get_line_net_convnext_fpn(num_classes=2).to(device)
+    model=linenet_resnet18_fpn()
     # trainer = Trainer()
     # trainer.train_cfg(model,cfg='./train.yaml')
-    model.train_by_cfg(cfg='train.yaml')
+    model.start_train(cfg='train.yaml')

+ 253 - 88
models/line_detect/trainer.py

@@ -1,18 +1,20 @@
-
 import os
 import time
 from datetime import datetime
 
+import numpy as np
 import torch
+from matplotlib import pyplot as plt
 from torch.utils.tensorboard import SummaryWriter
 
+from libs.vision_libs.utils import draw_bounding_boxes
 from models.base.base_model import BaseModel
 from models.base.base_trainer import BaseTrainer
 from models.config.config_tool import read_yaml
 from models.line_detect.dataset_LD import WirePointDataset
-from models.line_detect.postprocess import box_line_, show_
-from utils.log_util import show_line, save_last_model, save_best_model
+from models.wirenet.postprocess import postprocess
 from tools import utils
+from torchvision import transforms
 
 
 def _loss(losses):
@@ -26,26 +28,38 @@ def _loss(losses):
     for j, name in enumerate(loss_labels_k):
         loss = loss_labels[0][name].mean()
         total_loss += loss
-
     return total_loss
+
+
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-def move_to_device(data, device):
-    if isinstance(data, (list, tuple)):
-        return type(data)(move_to_device(item, device) for item in data)
-    elif isinstance(data, dict):
-        return {key: move_to_device(value, device) for key, value in data.items()}
-    elif isinstance(data, torch.Tensor):
-        return data.to(device)
-    else:
-        return data  # leave non-tensor data unchanged
+
 
 class Trainer(BaseTrainer):
-    def __init__(self, model=None,
-                 dataset=None,
-                 device='cuda',
-                 **kwargs):
+    def __init__(self, model=None, **kwargs):
+        super().__init__(model, device, **kwargs)
+        self.model = model
+        print(f'kwargs:{kwargs}')
+        self.init_params(**kwargs)
 
-        super().__init__(model,dataset,device,**kwargs)
+    def init_params(self, **kwargs):
+        if kwargs != {}:
+            print(f'train_params:{kwargs["train_params"]}')
+            self.freeze_config = kwargs['train_params']['freeze_params']
+            print(f'freeze_config:{self.freeze_config}')
+            self.dataset_path = kwargs['io']['datadir']
+            self.batch_size = kwargs['train_params']['batch_size']
+            self.num_workers = kwargs['train_params']['num_workers']
+            self.logdir = kwargs['io']['logdir']
+            self.resume_from = kwargs['train_params']['resume_from']
+            self.optim = ''
+            self.train_result_ptath = os.path.join(self.logdir, datetime.now().strftime("%Y%m%d_%H%M%S"))
+            self.wts_path = os.path.join(self.train_result_ptath, 'weights')
+            self.tb_path = os.path.join(self.train_result_ptath, 'logs')
+            self.writer = SummaryWriter(self.tb_path)
+            self.last_model_path = os.path.join(self.wts_path, 'last.pth')
+            self.best_train_model_path = os.path.join(self.wts_path, 'best_train.pth')
+            self.best_val_model_path = os.path.join(self.wts_path, 'best_val.pth')
+            self.max_epoch = kwargs['train_params']['max_epoch']
 
     def move_to_device(self, data, device):
         if isinstance(data, (list, tuple)):
@@ -57,7 +71,63 @@ class Trainer(BaseTrainer):
         else:
             return data  # leave non-tensor data unchanged
 
-    def load_best_model(self,model, optimizer, save_path, device):
+    def freeze_params(self, model):
+        """Freeze model parameters according to the configuration."""
+        default_config = {
+            'backbone': True,  # freeze the backbone
+            'rpn': False,  # keep the rpn trainable
+            'roi_heads': {
+                'box_head': False,
+                'box_predictor': False,
+                'line_head': False,
+                'line_predictor': {
+                    'fc1': False,
+                    'fc2': {
+                        '0': False,
+                        '2': False,
+                        '4': False
+                    }
+                }
+            }
+        }
+
+        # update the default configuration
+        default_config.update(self.freeze_config)
+        config = default_config
+
+        print("\n===== Parameter Freezing Configuration =====")
+        for name, module in model.named_children():
+            if name in config:
+                if isinstance(config[name], bool):
+                    for param in module.parameters():
+                        param.requires_grad = not config[name]
+                    print(f"{'Frozen' if config[name] else 'Trainable'} module: {name}")
+
+                elif isinstance(config[name], dict):
+                    for subname, submodule in module.named_children():
+                        if subname in config[name]:
+                            if isinstance(config[name][subname], bool):
+                                for param in submodule.parameters():
+                                    param.requires_grad = not config[name][subname]
+                                print(
+                                    f"{'Frozen' if config[name][subname] else 'Trainable'} submodule: {name}.{subname}")
+
+                            elif isinstance(config[name][subname], dict):
+                                for subsubname, subsubmodule in submodule.named_children():
+                                    if subsubname in config[name][subname]:
+                                        for param in subsubmodule.parameters():
+                                            param.requires_grad = not config[name][subname][subsubname]
+                                        print(
+                                            f"{'Frozen' if config[name][subname][subsubname] else 'Trainable'} sub-submodule: {name}.{subname}.{subsubname}")
+
+        # print parameter statistics
+        total_params = sum(p.numel() for p in model.parameters())
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print(f"\nTotal Parameters: {total_params:,}")
+        print(f"Trainable Parameters: {trainable_params:,}")
+        print(f"Frozen Parameters: {total_params - trainable_params:,}")
+
+    def load_best_model(self, model, optimizer, save_path, device):
         if os.path.exists(save_path):
             checkpoint = torch.load(save_path, map_location=device)
             model.load_state_dict(checkpoint['model_state_dict'])
@@ -70,111 +140,206 @@ class Trainer(BaseTrainer):
             print(f"No saved model found at {save_path}")
         return model, optimizer
 
-    def writer_loss(self, writer, losses, epoch):
+    def writer_predict_result(self, img, result, epoch):
+        img = img.cpu().detach()
+        im = img.permute(1, 2, 0)
+        self.writer.add_image("z-ori", im, epoch, dataformats="HWC")
+
+        boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), result[0]["boxes"],
+                                          colors="yellow", width=1)
+        self.writer.add_image("z-boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+
+        PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
+        # print(f'pred[1]:{pred[1]}')
+        heatmaps = result[-2][0]
+        print(f'heatmaps:{heatmaps.shape}')
+        jmap = heatmaps[1: 2].cpu().detach()
+        lmap = heatmaps[2: 3].cpu().detach()
+        self.writer.add_image("z-jmap", jmap, epoch)
+        self.writer.add_image("z-lmap", lmap, epoch)
+        # plt.imshow(lmap)
+        # plt.show()
+        H = result[-1]['wires']
+        lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
+        scores = H["score"][0].cpu().numpy()
+        for i in range(1, len(lines)):
+            if (lines[i] == lines[0]).all():
+                lines = lines[:i]
+                scores = scores[:i]
+                break
+
+        # postprocess lines to remove overlapped lines
+        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
+        nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
+
+        for i, t in enumerate([0]):
+            plt.gca().set_axis_off()
+            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+            plt.margins(0, 0)
+            for (a, b), s in zip(nlines, nscores):
+                if s < t:
+                    continue
+                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
+                plt.scatter(a[1], a[0], **PLTOPTS)
+                plt.scatter(b[1], b[0], **PLTOPTS)
+            plt.gca().xaxis.set_major_locator(plt.NullLocator())
+            plt.gca().yaxis.set_major_locator(plt.NullLocator())
+            plt.imshow(im)
+            plt.tight_layout()
+            fig = plt.gcf()
+            fig.canvas.draw()
+
+            width, height = fig.get_size_inches() * fig.get_dpi()  # get the figure size in pixels
+            tmp_img = fig.canvas.tostring_argb()
+            tmp_img_np = np.frombuffer(tmp_img, dtype=np.uint8)
+            tmp_img_np = tmp_img_np.reshape(int(height), int(width), 4)
+
+            img_rgb = tmp_img_np[:, :, 1:]  # extract the RGB channels, dropping the alpha channel
+
+            # image_from_plot = np.frombuffer(tmp_img[:,:,1:], dtype=np.uint8).reshape(
+            #     fig.canvas.get_width_height()[::-1] + (3,))
+            plt.close()
+
+            img2 = transforms.ToTensor()(img_rgb)
+
+            self.writer.add_image("z-output", img2, epoch)
+
+    def writer_loss(self, losses, epoch, phase='train'):
         try:
             for key, value in losses.items():
                 if key == 'loss_wirepoint':
                     for subdict in losses['loss_wirepoint']['losses']:
                         for subkey, subvalue in subdict.items():
-                            writer.add_scalar(f'loss/{subkey}',
-                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
-                                              epoch)
+                            self.writer.add_scalar(f'{phase}/loss/{subkey}',
+                                                   subvalue.item() if hasattr(subvalue, 'item') else subvalue,
+                                                   epoch)
                 elif isinstance(value, torch.Tensor):
-                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
+                    self.writer.add_scalar(f'{phase}/loss/{key}', value.item(), epoch)
         except Exception as e:
             print(f"TensorBoard logging error: {e}")
 
-    def train_cfg(self, model:BaseModel, cfg):
-        # cfg = r'./config/wireframe.yaml'
+    def train_from_cfg(self, model: BaseModel, cfg, freeze_config=None):  # new: supports passing in a freeze configuration
         cfg = read_yaml(cfg)
-        print(f'cfg:{cfg}')
-        # print(cfg['n_dyn_negl'])
+        # print(f'cfg:{cfg}')
+        # self.freeze_config = freeze_config or {}  # update the freeze configuration
+
         self.train(model, **cfg)
 
     def train(self, model, **kwargs):
-        dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
+
+        self.init_params(**kwargs)
+
+        dataset_train = WirePointDataset(dataset_path=self.dataset_path, dataset_type='train')
+        dataset_val = WirePointDataset(dataset_path=self.dataset_path, dataset_type='val')
+
         train_sampler = torch.utils.data.RandomSampler(dataset_train)
-        # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
-        train_collate_fn = utils.collate_fn_wirepoint
+        val_sampler = torch.utils.data.RandomSampler(dataset_val)
+        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=self.batch_size, drop_last=True)
+        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=self.batch_size, drop_last=True)
+        train_collate_fn = utils.collate_fn
+        val_collate_fn = utils.collate_fn
+
         data_loader_train = torch.utils.data.DataLoader(
-            dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
+            dataset_train, batch_sampler=train_batch_sampler, num_workers=self.num_workers, collate_fn=train_collate_fn
         )
-
-        dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
-        val_sampler = torch.utils.data.RandomSampler(dataset_val)
-        # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
-        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
-        val_collate_fn = utils.collate_fn_wirepoint
         data_loader_val = torch.utils.data.DataLoader(
-            dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
+            dataset_val, batch_sampler=val_batch_sampler, num_workers=self.num_workers, collate_fn=val_collate_fn
         )
 
-        train_result_ptath = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
-        wts_path = os.path.join(train_result_ptath, 'weights')
-        tb_path = os.path.join(train_result_ptath, 'logs')
-        writer = SummaryWriter(tb_path)
-
-        optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['optim']['lr'])
-        # writer = SummaryWriter(kwargs['io']['logdir'])
         model.to(device)
 
+        optimizer = torch.optim.Adam(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=kwargs['train_params']['optim']['lr']
+        )
 
+        best_train_loss = float('inf')
+        best_val_loss = float('inf')
+        for epoch in range(self.max_epoch):
+            print(f"train epoch:{epoch}")
 
-        # # load weights
-        # save_path = 'logs/pth/best_model.pth'
-        # model, optimizer = self.load_best_model(model, optimizer, save_path, device)
+            model, epoch_train_loss = self.one_epoch(model, data_loader_train, epoch, optimizer)
 
-        # logdir_with_pth = os.path.join(kwargs['io']['logdir'], 'pth')
-        # os.makedirs(logdir_with_pth, exist_ok=True)  # create the directory if it does not exist
-        last_model_path = os.path.join(wts_path, 'last.pth')
-        best_model_path = os.path.join(wts_path, 'best.pth')
-        global_step = 0
+            # ========== Validation ==========
+            with torch.no_grad():
+                model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val')
 
-        for epoch in range(kwargs['optim']['max_epoch']):
-            print(f"epoch:{epoch}")
-            total_train_loss = 0.0
+            self.save_last_model(model, self.last_model_path, epoch, optimizer)
+            best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch, epoch_train_loss,
+                                                   best_train_loss,
+                                                   optimizer)
+            best_val_loss = self.save_best_model(model, self.best_val_model_path, epoch, epoch_val_loss, best_val_loss,
+                                                 optimizer)
 
+    def one_epoch(self, model, data_loader, epoch, optimizer, phase='train'):
+        if phase == 'train':
             model.train()
+        if phase == 'val':
+            model.eval()
 
-            for imgs, targets in data_loader_train:
-                imgs = move_to_device(imgs, device)
-                targets=move_to_device(targets,device)
-                # print(f'imgs:{len(imgs)}')
-                # print(f'targets:{len(targets)}')
-                losses = model(imgs, targets)
-                loss = _loss(losses)
-                total_train_loss += loss.item()
+        total_loss = 0
+        epoch_step = 0
+        global_step = epoch * len(data_loader)
+        for imgs, targets in data_loader:
+            imgs = self.move_to_device(imgs, device)
+            targets = self.move_to_device(targets, device)
+            losses = model(imgs, targets)
+            loss = _loss(losses)
+            total_loss += loss.item()
+            if phase == 'train':
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
-                self.writer_loss(writer, losses, global_step)
-                global_step+=1
+            self.writer_loss(losses, global_step, phase=phase)
+            global_step += 1
 
+            if epoch_step == 0 and phase == 'val':
+                t_start = time.time()
+                print(f'start to predict:{t_start}')
+                result = model(self.move_to_device(imgs, self.device))
+                t_end = time.time()
+                print(f'predict used:{t_end - t_start}')
+                self.writer_predict_result(img=imgs[0], result=result, epoch=epoch)
+                epoch_step += 1
 
-            avg_train_loss = total_train_loss / len(data_loader_train)
-            if epoch == 0:
-                best_loss = avg_train_loss;
+        avg_loss = total_loss / len(data_loader)
+        print(f'{phase}/loss epoch{epoch}:{avg_loss:.4f}')
+        self.writer.add_scalar(f'loss/{phase}', avg_loss, epoch)
+        return model, avg_loss
 
-            writer.add_scalar('loss/train', avg_train_loss, epoch)
+    def save_best_model(self, model, save_path, epoch, current_loss, best_loss, optimizer=None):
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
 
+        if current_loss <= best_loss:
+            checkpoint = {
+                'epoch': epoch,
+                'model_state_dict': model.state_dict(),
+                'loss': current_loss
+            }
+            if optimizer is not None:
+                checkpoint['optimizer_state_dict'] = optimizer.state_dict()
 
-            if os.path.exists(f'{wts_path}/last.pt'):
-                os.remove(f'{wts_path}/last.pt')
-            # torch.save(model.state_dict(), f'{wts_path}/last.pt')
-            save_last_model(model,last_model_path,epoch,optimizer)
-            best_loss = save_best_model(model,best_model_path,epoch,avg_train_loss,best_loss,optimizer)
+            torch.save(checkpoint, save_path)
+            print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
 
-            model.eval()
-            with torch.no_grad():
-                for batch_idx, (imgs, targets) in enumerate(data_loader_val):
-                    t_start = time.time()
-                    print(f'start to predict:{t_start}')
-                    pred = model(self.move_to_device(imgs, self.device))
-                    t_end = time.time()
-                    print(f'predict used:{t_end - t_start}')
-                    if batch_idx == 0:
-                        show_line(imgs[0], pred, epoch, writer)
-                    break
+            return current_loss
+
+        return best_loss
+
+    def save_last_model(self, model, save_path, epoch, optimizer=None):
+
+        if os.path.exists(f'{self.wts_path}/last.pt'):
+            os.remove(f'{self.wts_path}/last.pt')
+
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+        checkpoint = {
+            'epoch': epoch,
+            'model_state_dict': model.state_dict(),
+        }
+
+        if optimizer is not None:
+            checkpoint['optimizer_state_dict'] = optimizer.state_dict()
+
+        torch.save(checkpoint, save_path)
 
 
+if __name__ == '__main__':
+    print('')
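For reference, the checkpoints written by the new save_best_model carry epoch, model_state_dict, loss and, when an optimizer is passed, optimizer_state_dict. A minimal resume sketch matching that format (the checkpoint path is illustrative):

    import torch
    from models.line_detect.line_net import linenet_resnet18_fpn

    ckpt_path = 'weights/best_val.pth'  # illustrative path under a run's train_results folder
    model = linenet_resnet18_fpn()
    optimizer = torch.optim.Adam(model.parameters(), lr=4.0e-4)

    checkpoint = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print(f"resumed from epoch {checkpoint['epoch']}, loss {checkpoint['loss']:.4f}")
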

+ 1 - 19
utils/log_util.py

@@ -66,25 +66,7 @@ def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=
     return best_loss
 
 
-# def show_line(img, pred, epoch, writer):
-#     fig = plt.figure(figsize=(15, 15))
-#
-#     # ... your plotting code here ...
-#
-#     # Save the figure to a BytesIO buffer
-#     buf = BytesIO()
-#     plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
-#     buf.seek(0)
-#
-#     # Load the image from the buffer and convert to numpy array
-#     image = Image.open(buf)
-#     image_from_plot = np.array(image)[..., :3]  # Keep RGB channels if there's an alpha
-#
-#     # Close the figure to free memory
-#     plt.close(fig)
-#
-#     # Log the image to TensorBoard or other logger
-#     writer.add_image('validate', image_from_plot, epoch, dataformats='HWC')
+
 def show_line(img, pred, epoch, writer):
     img=img.cpu().detach()
     im = img.permute(1, 2, 0)