import os
import time
from datetime import datetime

import torch
from torch.utils.tensorboard import SummaryWriter

from models.base.base_model import BaseModel
from models.base.base_trainer import BaseTrainer
from models.config.config_tool import read_yaml
from models.line_detect.dataset_LD import WirePointDataset
from models.line_detect.postprocess import box_line_, show_
from utils.log_util import show_line, save_last_model, save_best_model
from tools import utils


def _loss(losses):
    """Sum all loss terms; 'loss_wirepoint' holds a nested list of per-head loss dicts."""
    total_loss = 0
    for i in losses.keys():
        if i != "loss_wirepoint":
            total_loss += losses[i]
        else:
            loss_labels = losses[i]["losses"]
            loss_labels_k = list(loss_labels[0].keys())
            for name in loss_labels_k:
                loss = loss_labels[0][name].mean()
                total_loss += loss
    return total_loss


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def move_to_device(data, device):
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_device(item, device) for item in data)
    elif isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    elif isinstance(data, torch.Tensor):
        return data.to(device)
    else:
        return data  # leave non-tensor data unchanged


class Trainer(BaseTrainer):
    def __init__(self, model=None, dataset=None, device='cuda',
                 freeze_config=None,  # added: parameter-freezing configuration
                 **kwargs):
        super().__init__(model, dataset, device, **kwargs)
        self.freeze_config = freeze_config or {}  # default: no freeze overrides

    def move_to_device(self, data, device):
        if isinstance(data, (list, tuple)):
            return type(data)(self.move_to_device(item, device) for item in data)
        elif isinstance(data, dict):
            return {key: self.move_to_device(value, device) for key, value in data.items()}
        elif isinstance(data, torch.Tensor):
            return data.to(device)
        else:
            return data  # leave non-tensor data unchanged

    def freeze_params(self, model):
        """Freeze model parameters according to the freeze configuration (True = frozen)."""
        default_config = {
            'backbone': True,   # freeze the backbone
            'rpn': False,       # keep the RPN trainable
            'roi_heads': {
                'box_head': False,
                'box_predictor': False,
                'line_head': False,
                'line_predictor': {
                    'fc1': False,
                    'fc2': {
                        '0': False,
                        '2': False,
                        '4': False
                    }
                }
            }
        }
        # merge user overrides into the defaults
        default_config.update(self.freeze_config)
        config = default_config

        print("\n===== Parameter Freezing Configuration =====")
        for name, module in model.named_children():
            if name in config:
                if isinstance(config[name], bool):
                    for param in module.parameters():
                        param.requires_grad = not config[name]
                    print(f"{'Frozen' if config[name] else 'Trainable'} module: {name}")
                elif isinstance(config[name], dict):
                    for subname, submodule in module.named_children():
                        if subname in config[name]:
                            if isinstance(config[name][subname], bool):
                                for param in submodule.parameters():
                                    param.requires_grad = not config[name][subname]
                                print(f"{'Frozen' if config[name][subname] else 'Trainable'} submodule: {name}.{subname}")
                            elif isinstance(config[name][subname], dict):
                                for subsubname, subsubmodule in submodule.named_children():
                                    if subsubname in config[name][subname]:
                                        for param in subsubmodule.parameters():
                                            param.requires_grad = not config[name][subname][subsubname]
                                        print(f"{'Frozen' if config[name][subname][subsubname] else 'Trainable'} sub-submodule: {name}.{subname}.{subsubname}")

        # print parameter statistics
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\nTotal Parameters: {total_params:,}")
        print(f"Trainable Parameters: {trainable_params:,}")
        print(f"Frozen Parameters: {total_params - trainable_params:,}")
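
    # Hypothetical sketch (not part of the original runs): freeze_config overrides the
    # defaults above, with True meaning "frozen". For example, freezing both the backbone
    # and the RPN while leaving roi_heads trainable could look like:
    #
    #   trainer = Trainer(freeze_config={'backbone': True, 'rpn': True})
    #   trainer.freeze_params(model)  # has an effect only if actually called; in train()
    #                                 # below this call is currently commented out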

    def load_best_model(self, model, optimizer, save_path, device):
        if os.path.exists(save_path):
            checkpoint = torch.load(save_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            if optimizer is not None:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epoch = checkpoint['epoch']
            loss = checkpoint['loss']
            print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
        else:
            print(f"No saved model found at {save_path}")
        return model, optimizer

    def writer_loss(self, writer, losses, epoch):
        try:
            for key, value in losses.items():
                if key == 'loss_wirepoint':
                    for subdict in losses['loss_wirepoint']['losses']:
                        for subkey, subvalue in subdict.items():
                            writer.add_scalar(f'loss/{subkey}',
                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
                                              epoch)
                elif isinstance(value, torch.Tensor):
                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
        except Exception as e:
            print(f"TensorBoard logging error: {e}")

    def train_cfg(self, model: BaseModel, cfg, freeze_config=None):  # added: accept a freeze configuration
        cfg = read_yaml(cfg)
        self.freeze_config = freeze_config or {}  # update the freeze configuration
        self.train(model, **cfg)
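
    # Sketch of the config keys that train() reads from the YAML passed to train_cfg().
    # The key names come from the code below; the values shown are placeholders only:
    #
    #   io:
    #     datadir: <path/to/wirepoint/dataset>
    #   optim:
    #     lr: <learning rate>
    #     max_epoch: <number of epochs>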

    def train(self, model, **kwargs):
        dataset_train = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='train')
        train_sampler = torch.utils.data.RandomSampler(dataset_train)
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=4, drop_last=True)
        train_collate_fn = utils.collate_fn_wirepoint
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train, batch_sampler=train_batch_sampler, num_workers=1, collate_fn=train_collate_fn
        )

        dataset_val = WirePointDataset(dataset_path=kwargs['io']['datadir'], dataset_type='val')
        val_sampler = torch.utils.data.RandomSampler(dataset_val)
        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=4, drop_last=True)
        val_collate_fn = utils.collate_fn_wirepoint
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, batch_sampler=val_batch_sampler, num_workers=1, collate_fn=val_collate_fn
        )

        train_result_path = os.path.join('train_results', datetime.now().strftime("%Y%m%d_%H%M%S"))
        wts_path = os.path.join(train_result_path, 'weights')
        tb_path = os.path.join(train_result_path, 'logs')
        os.makedirs(wts_path, exist_ok=True)
        writer = SummaryWriter(tb_path)

        model.to(device)

        # # load pretrained weights
        # save_path = r"F:\BaiduNetdiskDownload\r50fpn_wts_e350\best.pth"
        # model, _ = self.load_best_model(model, None, save_path, device)

        # freeze parameters
        # self.freeze_params(model)

        # initialize the optimizer (only unfrozen parameters are trained)
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=kwargs['optim']['lr']
        )

        last_model_path = os.path.join(wts_path, 'last.pth')
        best_model_path = os.path.join(wts_path, 'best.pth')
        global_step = 0

        for epoch in range(kwargs['optim']['max_epoch']):
            print(f"epoch:{epoch}")
            total_train_loss = 0.0

            model.train()
            for imgs, targets in data_loader_train:
                imgs = move_to_device(imgs, device)
                targets = move_to_device(targets, device)
                losses = model(imgs, targets)
                loss = _loss(losses)
                total_train_loss += loss.item()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                self.writer_loss(writer, losses, global_step)
                global_step += 1

            avg_train_loss = total_train_loss / len(data_loader_train)
            if epoch == 0:
                best_loss = avg_train_loss

            writer.add_scalar('loss/train', avg_train_loss, epoch)

            # remove the previous "last" checkpoint before saving the new one
            if os.path.exists(last_model_path):
                os.remove(last_model_path)
            save_last_model(model, last_model_path, epoch, optimizer)
            best_loss = save_best_model(model, best_model_path, epoch, avg_train_loss, best_loss, optimizer)

            model.eval()
            with torch.no_grad():
                for batch_idx, (imgs, targets) in enumerate(data_loader_val):
                    t_start = time.time()
                    print(f'start to predict:{t_start}')
                    pred = model(self.move_to_device(imgs, self.device))
                    # print(f'pred:{pred}')
                    t_end = time.time()
                    print(f'predict used:{t_end - t_start}')
                    if batch_idx == 0:
                        show_line(imgs[0], pred, epoch, writer)
                    break  # only one validation batch is visualized per epoch


from models.line_detect.line_net import linenet_resnet50_fpn, LineNet, linenet_resnet18_fpn, get_line_net_efficientnetv2

if __name__ == '__main__':
    # model = LineNet('line_net.yaml')
    model = linenet_resnet50_fpn().to(device)
    # model = get_line_net_efficientnetv2(2, pretrained_backbone=True).to(device)
    # model = linenet_resnet18_fpn()
    trainer = Trainer()
    trainer.train_cfg(model, cfg='./train.yaml')
    # model.train_by_cfg(cfg='train.yaml')  # alternative entry point on the model itself

    # trainer = Trainer()
    # trainer.train_cfg(model=model, cfg='train.yaml')
    #
    # pt_path = r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth"
    # img_path = r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png"
    # model.predict(pt_path, model, img_path, type=1, threshold=0, save_path=None, show=True)

    # model = model.load_best_model(model, r"\\192.168.50.222\share\lm\weight\20250424_163159——1\weights\best.pth")
    # img_path = r"D:\python\PycharmProjects\data_20250223\0423_\images\val\2025-04-23-09-43-28_SaveLeftImage.png"
    # model.predict1(model, img_path, type=1, threshold=0, save_path=None, show=True)
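
    # Hypothetical sketch (not from the original script): resuming from a previously
    # saved checkpoint before training. The path is a placeholder, and it assumes the
    # checkpoint layout matches what load_best_model() expects.
    # resume_path = os.path.join('train_results', '<run_timestamp>', 'weights', 'best.pth')
    # model, _ = trainer.load_best_model(model, None, resume_path, device)
    # trainer.train_cfg(model, cfg='./train.yaml')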