"""Training script for the wire-point / line detection R-CNN model (2025/2/9).

Loads a YAML config, builds train/val ``WirePointDataset`` loaders, trains
``linercnn_resnet50_fpn`` with Adam, and logs losses plus line-prediction
visualizations to TensorBoard.

NOTE(review): a legacy LCNN-based trainer (2025/2/7) that previously lived in a
module-level triple-quoted string was removed as dead code — recover it from
version control if needed.
"""
import os
from typing import Optional, Any

import cv2
import numpy as np
import torch
from models.config.config_tool import read_yaml
from models.line_detect.dataset_LD import WirePointDataset
from tools import utils

from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import matplotlib as mpl
from skimage import io

from models.line_detect.line_rcnn import linercnn_resnet50_fpn

from torchvision.utils import draw_bounding_boxes
from models.wirenet.postprocess import postprocess
from torchvision import transforms
from collections import OrderedDict

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _loss(losses):
    """Sum all loss terms in *losses* into a single scalar tensor.

    Plain entries are added directly; the nested ``"loss_wirepoint"`` entry
    contributes the mean of each tensor in ``losses["loss_wirepoint"]["losses"][0]``.

    :param losses: dict of loss tensors as returned by the model in train mode;
        ``losses["loss_wirepoint"]["losses"]`` is assumed to be a non-empty list
        of dicts of tensors (only element 0 is used — matches the original code).
    :return: scalar total loss suitable for ``backward()``.
    """
    total_loss = 0
    for key, value in losses.items():
        if key != "loss_wirepoint":
            total_loss += value
        else:
            # Only the first sub-dict of wirepoint losses is reduced,
            # mirroring the upstream LCNN convention.
            wirepoint_losses = value["losses"][0]
            for sub_loss in wirepoint_losses.values():
                total_loss += sub_loss.mean()
    return total_loss


# Shared colormap used to color predicted lines by confidence score.
# Scores are normalized to [0.4, 1.0] before mapping onto "jet".
cmap = plt.get_cmap("jet")
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])


def c(x):
    """Map a confidence score *x* to an RGBA color via the shared colormap."""
    return sm.to_rgba(x)


def imshow(im):
    """Display image *im* with the shared score colorbar.

    NOTE(review): axis limits use ``im.shape[0]`` for both x and y, which is
    only correct for square images — confirm whether non-square inputs occur.
    """
    plt.close()
    plt.tight_layout()
    plt.imshow(im)
    plt.colorbar(sm, fraction=0.046)
    plt.xlim([0, im.shape[0]])
    plt.ylim([im.shape[0], 0])


def show_line(img, pred, epoch, writer):
    """Log the input image, predicted boxes, and predicted lines to TensorBoard.

    :param img: CHW image tensor (values in [0, 1] — assumed; scaled by 255
        for ``draw_bounding_boxes``).
    :param pred: model output; ``pred[0]["boxes"]`` holds detection boxes and
        ``pred[1]['wires']`` holds line predictions (``"lines"`` in a
        128x128 coordinate frame, plus per-line ``"score"``).
    :param epoch: global step used for all TensorBoard entries.
    :param writer: ``SummaryWriter`` receiving "ori", "boxes" and "output".
    """
    im = img.permute(1, 2, 0)
    writer.add_image("ori", im, epoch, dataformats="HWC")

    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
                                      colors="yellow", width=1)
    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")

    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
    H = pred[1]['wires']
    # Lines are predicted on a 128x128 grid; rescale to image pixels.
    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
    scores = H["score"][0].cpu().numpy()
    # Predictions are padded by repeating the first line; truncate at the
    # first repeat to keep only genuine detections.
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            lines = lines[:i]
            scores = scores[:i]
            break

    # postprocess lines to remove overlapped lines
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

    # One rendering pass per score threshold (currently just 0.8).
    for t in [0.8]:
        plt.gca().set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        for (a, b), s in zip(nlines, nscores):
            if s < t:
                continue
            # Matplotlib plots (x, y) while line endpoints are (row, col),
            # hence the index swap.
            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
            plt.scatter(a[1], a[0], **PLTOPTS)
            plt.scatter(b[1], b[0], **PLTOPTS)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(im)
        plt.tight_layout()
        fig = plt.gcf()
        fig.canvas.draw()
        # Rasterize the current figure into an HxWx3 uint8 array.
        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
            fig.canvas.get_width_height()[::-1] + (3,))
        plt.close()
        img2 = transforms.ToTensor()(image_from_plot)
        writer.add_image("output", img2, epoch)


if __name__ == '__main__':
    cfg = r'D:\python\PycharmProjects\lcnn-master\lcnn_\MultiVisionModels\config\wireframe.yaml'
    cfg = read_yaml(cfg)
    print(f'cfg:{cfg}')
    print(cfg['model']['n_dyn_negl'])
    # net = WirepointPredictor()

    # --- data loaders (batch size 1, single process) ---
    dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
    train_sampler = torch.utils.data.RandomSampler(dataset_train)
    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=1, drop_last=True)
    train_collate_fn = utils.collate_fn_wirepoint
    data_loader_train = torch.utils.data.DataLoader(dataset_train, batch_sampler=train_batch_sampler,
                                                    num_workers=0, collate_fn=train_collate_fn)

    dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
    val_sampler = torch.utils.data.RandomSampler(dataset_val)
    # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=1, drop_last=True)
    val_collate_fn = utils.collate_fn_wirepoint
    data_loader_val = torch.utils.data.DataLoader(dataset_val, batch_sampler=val_batch_sampler,
                                                  num_workers=0, collate_fn=val_collate_fn)

    model = linercnn_resnet50_fpn().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
    writer = SummaryWriter(cfg['io']['logdir'])

    def move_to_device(data, device):
        """Recursively move tensors in nested lists/tuples/dicts to *device*.

        Non-tensor leaves are returned unchanged.
        """
        if isinstance(data, (list, tuple)):
            return type(data)(move_to_device(item, device) for item in data)
        elif isinstance(data, dict):
            return {key: move_to_device(value, device) for key, value in data.items()}
        elif isinstance(data, torch.Tensor):
            return data.to(device)
        else:
            return data  # leave non-tensor data untouched

    def writer_loss(writer, losses, epoch):
        """Log each loss term under the 'loss/' TensorBoard namespace.

        The nested ``loss_wirepoint`` entry is flattened into one scalar per
        sub-loss. Logging failures are deliberately non-fatal (best effort).
        """
        try:
            for key, value in losses.items():
                if key == 'loss_wirepoint':
                    for subdict in losses['loss_wirepoint']['losses']:
                        for subkey, subvalue in subdict.items():
                            writer.add_scalar(f'loss/{subkey}',
                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
                                              epoch)
                elif isinstance(value, torch.Tensor):
                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
        except Exception as e:
            print(f"TensorBoard logging error: {e}")

    for epoch in range(cfg['optim']['max_epoch']):
        print(f"epoch:{epoch}")
        model.train()

        for imgs, targets in data_loader_train:
            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
            # print(losses)
            loss = _loss(losses)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            writer_loss(writer, losses, epoch)

        # Visualize predictions for a single validation batch per epoch.
        model.eval()
        with torch.no_grad():
            for batch_idx, (imgs, targets) in enumerate(data_loader_val):
                pred = model(move_to_device(imgs, device))
                if batch_idx == 0:
                    show_line(imgs[0], pred, epoch, writer)
                break