import os
import time
from datetime import datetime

import matplotlib as mpl
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from libs.vision_libs.utils import draw_bounding_boxes
from models.base.base_model import BaseModel
from models.base.base_trainer import BaseTrainer
from models.config.config_tool import read_yaml
from models.line_net.dataset_LD import WirePointDataset
from models.wirenet.postprocess import postprocess
from tools import utils

cmap = plt.get_cmap("jet")
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])


def _loss(losses):
    """Sum all loss terms into a single scalar.

    Plain entries are scalar tensors; the 'loss_wirepoint' entry nests its
    terms under losses['loss_wirepoint']['losses'][0] as a dict of tensors.
    """
    total_loss = 0
    for i in losses.keys():
        if i != "loss_wirepoint":
            total_loss += losses[i]
        else:
            loss_labels = losses[i]["losses"]
            loss_labels_k = list(loss_labels[0].keys())
            for j, name in enumerate(loss_labels_k):
                loss = loss_labels[0][name].mean()
                total_loss += loss
    return total_loss


def c(x):
    return sm.to_rgba(x)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Trainer(BaseTrainer):
    def __init__(self, model=None, **kwargs):
        super().__init__(model, device, **kwargs)
        self.model = model
        # print(f'kwargs:{kwargs}')
        self.init_params(**kwargs)

    def init_params(self, **kwargs):
        if kwargs != {}:
            print(f'train_params:{kwargs["train_params"]}')
            self.freeze_config = kwargs['train_params']['freeze_params']
            print(f'freeze_config:{self.freeze_config}')
            self.dataset_path = kwargs['io']['datadir']
            self.batch_size = kwargs['train_params']['batch_size']
            self.num_workers = kwargs['train_params']['num_workers']
            self.logdir = kwargs['io']['logdir']
            self.resume_from = kwargs['train_params']['resume_from']
            self.optim = ''
            self.train_result_path = os.path.join(self.logdir, datetime.now().strftime("%Y%m%d_%H%M%S"))
            self.wts_path = os.path.join(self.train_result_path, 'weights')
            self.tb_path = os.path.join(self.train_result_path, 'logs')
            self.writer = SummaryWriter(self.tb_path)
            self.last_model_path = os.path.join(self.wts_path, 'last.pth')
            self.best_train_model_path = os.path.join(self.wts_path, 'best_train.pth')
            self.best_val_model_path = os.path.join(self.wts_path, 'best_val.pth')
            self.max_epoch = kwargs['train_params']['max_epoch']

    def move_to_device(self, data, device):
        if isinstance(data, (list, tuple)):
            return type(data)(self.move_to_device(item, device) for item in data)
        elif isinstance(data, dict):
            return {key: self.move_to_device(value, device) for key, value in data.items()}
        elif isinstance(data, torch.Tensor):
            return data.to(device)
        else:
            return data  # leave non-tensor data unchanged

    def freeze_params(self, model):
        """Freeze model parameters according to the freeze config."""
        default_config = {
            'backbone': True,  # freeze the backbone
            'rpn': False,  # keep the RPN trainable
            'roi_heads': {
                'box_head': False,
                'box_predictor': False,
                'line_head': False,
                'line_predictor': {
                    'fc1': False,
                    'fc2': {
                        '0': False,
                        '2': False,
                        '4': False
                    }
                }
            }
        }

        # merge the user-supplied config into the defaults
        default_config.update(self.freeze_config)
        config = default_config

        print("\n===== Parameter Freezing Configuration =====")
        for name, module in model.named_children():
            if name in config:
                if isinstance(config[name], bool):
                    for param in module.parameters():
                        param.requires_grad = not config[name]
                    print(f"{'Frozen' if config[name] else 'Trainable'} module: {name}")
                elif isinstance(config[name], dict):
                    for subname, submodule in module.named_children():
                        if subname in config[name]:
                            if isinstance(config[name][subname], bool):
                                for param in submodule.parameters():
                                    param.requires_grad = not config[name][subname]
                                print(
                                    f"{'Frozen' if config[name][subname] else 'Trainable'} submodule: {name}.{subname}")
                            elif isinstance(config[name][subname], dict):
                                for subsubname, subsubmodule in submodule.named_children():
                                    if subsubname in config[name][subname]:
                                        for param in subsubmodule.parameters():
                                            param.requires_grad = not config[name][subname][subsubname]
                                        print(
                                            f"{'Frozen' if config[name][subname][subsubname] else 'Trainable'} sub-submodule: {name}.{subname}.{subsubname}")

        # parameter statistics
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\nTotal Parameters: {total_params:,}")
        print(f"Trainable Parameters: {trainable_params:,}")
        print(f"Frozen Parameters: {total_params - trainable_params:,}")
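    # Illustrative 'train_params.freeze_params' YAML section (values are
    # examples only). Key names mirror default_config above; note that the
    # merge is a shallow dict.update, so a nested dict given here replaces
    # the corresponding default subtree wholesale:
    #
    #   freeze_params:
    #     backbone: true
    #     rpn: false
    #     roi_heads:
    #       line_head: true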
    def load_best_model(self, model, optimizer, save_path, device):
        if os.path.exists(save_path):
            checkpoint = torch.load(save_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            if optimizer is not None:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epoch = checkpoint['epoch']
            loss = checkpoint['loss']
            print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
        else:
            print(f"No saved model found at {save_path}")
        return model, optimizer

    def writer_predict_result(self, img, result, epoch):
        img = img.cpu().detach()
        im = img.permute(1, 2, 0)
        self.writer.add_image("z-ori", im, epoch, dataformats="HWC")

        boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), result[0]["boxes"],
                                          colors="yellow", width=1)
        self.writer.add_image("z-boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")

        PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
        # print(f'pred[1]:{pred[1]}')
        heatmaps = result[-2][0]
        print(f'heatmaps:{heatmaps.shape}')
        jmap = heatmaps[1:2].cpu().detach()
        lmap = heatmaps[2:3].cpu().detach()
        self.writer.add_image("z-jmap", jmap, epoch)
        self.writer.add_image("z-lmap", lmap, epoch)
        # plt.imshow(lmap)
        # plt.show()

        H = result[-1]['wires']
        # lines = H["lines"][0].cpu().numpy()
        lines = result[0]["lines"]
        # one uniform score per line; the original assigned a bare scalar (100),
        # which would break the slicing and postprocessing below
        scores = torch.full((len(lines),), 100.0)
        for i in range(1, len(lines)):
            if (lines[i] == lines[0]).all():
                lines = lines[:i]
                scores = scores[:i]
                break

        # postprocess lines to remove overlapped lines
        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
        nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

        for i, t in enumerate([0]):
            plt.gca().set_axis_off()
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            plt.margins(0, 0)
            for (a, b), s in zip(nlines, nscores):
                if s < t:
                    continue
                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
                plt.scatter(a[1], a[0], **PLTOPTS)
                plt.scatter(b[1], b[0], **PLTOPTS)
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.imshow(im)
            plt.tight_layout()

        fig = plt.gcf()
        fig.canvas.draw()
        width, height = fig.get_size_inches() * fig.get_dpi()  # figure size in pixels
        tmp_img = fig.canvas.tostring_argb()
        tmp_img_np = np.frombuffer(tmp_img, dtype=np.uint8)
        tmp_img_np = tmp_img_np.reshape(int(height), int(width), 4)
        img_rgb = tmp_img_np[:, :, 1:]  # keep the RGB channels, drop alpha
        # image_from_plot = np.frombuffer(tmp_img[:,:,1:], dtype=np.uint8).reshape(
        #     fig.canvas.get_width_height()[::-1] + (3,))
        plt.close()

        img2 = transforms.ToTensor()(img_rgb)
        self.writer.add_image("z-output", img2, epoch)

    def writer_loss(self, losses, epoch, phase='train'):
        try:
            for key, value in losses.items():
                if key == 'loss_wirepoint':
                    for subdict in losses['loss_wirepoint']['losses']:
                        for subkey, subvalue in subdict.items():
                            self.writer.add_scalar(f'{phase}/loss/{subkey}',
                                                   subvalue.item() if hasattr(subvalue, 'item') else subvalue,
                                                   epoch)
                elif isinstance(value, torch.Tensor):
                    self.writer.add_scalar(f'{phase}/loss/{key}', value.item(), epoch)
        except Exception as e:
            print(f"TensorBoard logging error: {e}")
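    # The 'losses' dict consumed above (and by _loss) has roughly this shape;
    # the individual term names are illustrative, not guaranteed by the model:
    #
    #   {
    #       'loss_classifier': tensor(...),   # plain scalar-tensor terms
    #       ...,
    #       'loss_wirepoint': {
    #           'losses': [
    #               {'loss_jmap': tensor(...), 'loss_lmap': tensor(...), ...}
    #           ],
    #       },
    #   }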
    def train_from_cfg(self, model: BaseModel, cfg, freeze_config=None):  # optionally accepts a freeze config
        cfg = read_yaml(cfg)
        # print(f'cfg:{cfg}')
        # self.freeze_config = freeze_config or {}  # override the freeze config
        self.train(model, **cfg)

    def train(self, model, **kwargs):
        self.init_params(**kwargs)
        dataset_train = WirePointDataset(dataset_path=self.dataset_path, dataset_type='train')
        dataset_val = WirePointDataset(dataset_path=self.dataset_path, dataset_type='val')

        train_sampler = torch.utils.data.RandomSampler(dataset_train)
        val_sampler = torch.utils.data.RandomSampler(dataset_val)
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=self.batch_size, drop_last=True)
        val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=self.batch_size, drop_last=True)
        train_collate_fn = utils.collate_fn
        val_collate_fn = utils.collate_fn

        data_loader_train = torch.utils.data.DataLoader(
            dataset_train, batch_sampler=train_batch_sampler,
            num_workers=self.num_workers, collate_fn=train_collate_fn
        )
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val, batch_sampler=val_batch_sampler,
            num_workers=self.num_workers, collate_fn=val_collate_fn
        )

        model.to(device)
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=kwargs['train_params']['optim']['lr']
        )

        for epoch in range(self.max_epoch):
            print(f"train epoch:{epoch}")
            model, epoch_train_loss = self.one_epoch(model, data_loader_train, epoch, optimizer)

            # ========== Validation ==========
            with torch.no_grad():
                model, epoch_val_loss = self.one_epoch(model, data_loader_val, epoch, optimizer, phase='val')

            if epoch == 0:
                best_train_loss = epoch_train_loss
                best_val_loss = epoch_val_loss

            self.save_last_model(model, self.last_model_path, epoch, optimizer)
            best_train_loss = self.save_best_model(model, self.best_train_model_path, epoch,
                                                   epoch_train_loss, best_train_loss, optimizer)
            best_val_loss = self.save_best_model(model, self.best_val_model_path, epoch,
                                                 epoch_val_loss, best_val_loss, optimizer)

    def one_epoch(self, model, data_loader, epoch, optimizer, phase='train'):
        if phase == 'train':
            model.train()
        if phase == 'val':
            model.eval()

        total_loss = 0
        epoch_step = 0
        global_step = epoch * len(data_loader)

        for imgs, targets in data_loader:
            imgs = self.move_to_device(imgs, device)
            targets = self.move_to_device(targets, device)

            if phase == 'val':
                result, losses = model(imgs, targets)
            else:
                losses = model(imgs, targets)

            loss = _loss(losses)
            total_loss += loss.item()

            if phase == 'train':
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            self.writer_loss(losses, global_step, phase=phase)
            global_step += 1

            # on the first validation batch, run inference once and log the
            # predicted boxes/heatmaps/lines to TensorBoard
            if epoch_step == 0 and phase == 'val':
                t_start = time.time()
                print(f'start to predict:{t_start}')
                result = model(self.move_to_device(imgs, self.device))
                t_end = time.time()
                print(f'predict used:{t_end - t_start}')
                self.writer_predict_result(img=imgs[0], result=result, epoch=epoch)
            epoch_step += 1

        avg_loss = total_loss / len(data_loader)
        print(f'{phase}/loss epoch{epoch}:{avg_loss:.4f}')
        self.writer.add_scalar(f'loss/{phase}', avg_loss, epoch)
        return model, avg_loss
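    # Checkpoint layout shared by the save helpers below and load_best_model
    # above ('optimizer_state_dict' and 'loss' are only present when the
    # writing helper includes them):
    #
    #   {
    #       'epoch': int,
    #       'model_state_dict': model.state_dict(),
    #       'optimizer_state_dict': optimizer.state_dict(),
    #       'loss': float,
    #   }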
    def save_best_model(self, model, save_path, epoch, current_loss, best_loss, optimizer=None):
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        if current_loss <= best_loss:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'loss': current_loss
            }
            if optimizer is not None:
                checkpoint['optimizer_state_dict'] = optimizer.state_dict()
            torch.save(checkpoint, save_path)
            print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
            return current_loss
        return best_loss

    def save_last_model(self, model, save_path, epoch, optimizer=None):
        # remove the previous checkpoint before writing the new one
        # (the original removed '{self.wts_path}/last.pt', which never matched
        # the '.pth' file actually saved here)
        if os.path.exists(save_path):
            os.remove(save_path)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
        }
        if optimizer is not None:
            checkpoint['optimizer_state_dict'] = optimizer.state_dict()
        torch.save(checkpoint, save_path)


if __name__ == '__main__':
    print('')
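    # Minimal usage sketch. The config path and the LineNet import are
    # assumptions for illustration only, not paths/classes defined in this
    # file; the YAML must provide the 'io' and 'train_params' sections read
    # by init_params above.
    #
    #   from models.line_net.line_net import LineNet
    #
    #   model = LineNet()
    #   trainer = Trainer()
    #   trainer.train_from_cfg(model, cfg='config/wireframe.yaml')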