# 2025/2/9
import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.utils import draw_bounding_boxes

from models.config.config_tool import read_yaml
from models.line_detect.dataset_LD import WirePointDataset
from models.line_detect.line_net import linenet_resnet50_fpn
from models.line_detect.postprocess import box_line_, show_
from models.wirenet.postprocess import postprocess
from tools import utils

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _loss(losses):
    """Sum all loss terms, flattening the nested 'loss_wirepoint' sub-losses."""
    total_loss = 0
    for i in losses.keys():
        if i != "loss_wirepoint":
            total_loss += losses[i]
        else:
            loss_labels = losses[i]["losses"]
            loss_labels_k = list(loss_labels[0].keys())
            for j, name in enumerate(loss_labels_k):
                loss = loss_labels[0][name].mean()
                total_loss += loss
    return total_loss


cmap = plt.get_cmap("jet")
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])


def c(x):
    return sm.to_rgba(x)


def imshow(im):
    plt.close()
    plt.tight_layout()
    plt.imshow(im)
    plt.colorbar(sm, fraction=0.046)
    plt.xlim([0, im.shape[1]])  # width is shape[1]
    plt.ylim([im.shape[0], 0])


def show_line(img, pred, epoch, writer):
    """Log the original image, predicted boxes, and predicted lines to TensorBoard."""
    im = img.permute(1, 2, 0)
    writer.add_image("ori", im, epoch, dataformats="HWC")

    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
                                      colors="yellow", width=1)
    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")

    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
    H = pred[-1]['wires']
    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
    scores = H["score"][0].cpu().numpy()
    # Truncate at the first duplicate of the top line (padding marker)
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            lines = lines[:i]
            scores = scores[:i]
            break

    # Postprocess lines to remove overlapping lines
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

    for i, t in enumerate([0.85]):
        plt.gca().set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        for (a, b), s in zip(nlines, nscores):
            if s < t:
                continue
            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
            plt.scatter(a[1], a[0], **PLTOPTS)
            plt.scatter(b[1], b[0], **PLTOPTS)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(im)
        plt.tight_layout()

        fig = plt.gcf()
        fig.canvas.draw()
        width, height = fig.get_size_inches() * fig.get_dpi()  # figure size in pixels
        tmp_img = fig.canvas.tostring_argb()
        tmp_img_np = np.frombuffer(tmp_img, dtype=np.uint8)
        tmp_img_np = tmp_img_np.reshape(int(height), int(width), 4)
        img_rgb = tmp_img_np[:, :, 1:]  # keep the RGB channels, drop the alpha channel

        img2 = transforms.ToTensor()(img_rgb)
        writer.add_image("z-output", img2, epoch)
        plt.close(fig)  # close the figure so plots do not accumulate across epochs


def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
    """Save a checkpoint only when the current loss improves on the best seen so far."""
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    if current_loss < best_loss:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'loss': current_loss
        }
        if optimizer is not None:
            checkpoint['optimizer_state_dict'] = optimizer.state_dict()
        torch.save(checkpoint, save_path)
        print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
        return current_loss

    return best_loss
def save_latest_model(model, save_path, epoch, optimizer=None):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
    }
    if optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()

    torch.save(checkpoint, save_path)


def load_best_model(model, optimizer, save_path, device):
    if os.path.exists(save_path):
        checkpoint = torch.load(save_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
    else:
        print(f"No saved model found at {save_path}")
    return model, optimizer


# Appears intended as a method of a predictor class (e.g. the WirepointPredictor
# referenced in the main block below); kept here for reference.
def predict(self, img, show_boxes=True, show_keypoint=True, show_line=True, save=False, save_path=None):
    self.load_weight('weights/best.pt')
    self.__model.eval()

    if isinstance(img, str):
        img = Image.open(img).convert("RGB")

    # Preprocess the image
    img_tensor = self.transforms(img)
    with torch.no_grad():
        predictions = self.__model([img_tensor])

    # Post-process the predictions
    boxes = predictions[0]['boxes'].cpu().numpy()
    keypoints = predictions[0]['keypoints'].cpu().numpy()

    # Visualize the predictions
    if show_boxes or show_keypoint or show_line or save:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(np.array(img))

        if show_boxes:
            for box in boxes:
                x0, y0, x1, y1 = box
                ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
                                           fill=False, edgecolor='yellow', linewidth=1))

        for (a, b) in keypoints:
            if show_keypoint:
                ax.scatter(a[0], a[1], c='c', s=2)
                ax.scatter(b[0], b[1], c='c', s=2)
            if show_line:
                ax.plot([a[0], b[0]], [a[1], b[1]], c='red', linewidth=1)

        if show_boxes or show_keypoint or show_line:
            plt.show()

        if save:
            fig.savefig(save_path)
            print(f"Prediction saved to {save_path}")
        plt.close(fig)


if __name__ == '__main__':
    cfg = r'./config/wireframe.yaml'
    cfg = read_yaml(cfg)
    print(f'cfg:{cfg}')
    print(cfg['model']['n_dyn_negl'])
    # net = WirepointPredictor()

    dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
    train_sampler = torch.utils.data.RandomSampler(dataset_train)
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
    train_collate_fn = utils.collate_fn_wirepoint
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
    )

    dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
    val_sampler = torch.utils.data.RandomSampler(dataset_val)
    val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
    val_collate_fn = utils.collate_fn_wirepoint
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
    )

    model = linenet_resnet50_fpn().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
    writer = SummaryWriter(cfg['io']['logdir'])

    # Load previously saved weights, if any
    save_path = 'logs/pth/best_model.pth'
    model, optimizer = load_best_model(model, optimizer, save_path, device)

    logdir_with_pth = os.path.join(cfg['io']['logdir'], 'pth')
    os.makedirs(logdir_with_pth, exist_ok=True)  # create the directory if it does not exist
    latest_model_path = os.path.join(logdir_with_pth, 'latest_model.pth')
    best_model_path = os.path.join(logdir_with_pth, 'best_model.pth')
    global_step = 0

    def move_to_device(data, device):
        if isinstance(data, (list, tuple)):
            return type(data)(move_to_device(item, device) for item in data)
        elif isinstance(data, dict):
            return {key: move_to_device(value, device) for key, value in data.items()}
        elif isinstance(data, torch.Tensor):
            return data.to(device)
        else:
            return data  # leave non-tensor data unchanged

    def writer_loss(writer, losses, epoch):
        try:
            for key, value in losses.items():
                if key == 'loss_wirepoint':
                    for subdict in losses['loss_wirepoint']['losses']:
                        for subkey, subvalue in subdict.items():
                            writer.add_scalar(f'loss/{subkey}',
                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue,
                                              epoch)
                elif isinstance(value, torch.Tensor):
                    writer.add_scalar(f'loss/{key}', value.item(), epoch)
        except Exception as e:
            print(f"TensorBoard logging error: {e}")

    best_loss = 10000  # initialized once, before the epoch loop, so improvements persist across epochs

    for epoch in range(cfg['optim']['max_epoch']):
        print(f"epoch:{epoch}")
        model.train()
        total_train_loss = 0.0

        for imgs, targets in data_loader_train:
            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
            loss = _loss(losses)
            total_train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            writer_loss(writer, losses, epoch)

        model.eval()
        with torch.no_grad():
            for batch_idx, (imgs, targets) in enumerate(data_loader_val):
                pred = model(move_to_device(imgs, device))
                # pred_ = box_line_(pred)  # match boxes with lines
                # show_(imgs, pred_, epoch, writer)
                if batch_idx == 0:
                    show_line(imgs[0], pred, epoch, writer)
                break

        avg_train_loss = total_train_loss / len(data_loader_train)
        writer.add_scalar('loss/train', avg_train_loss, epoch)

        save_latest_model(model, latest_model_path, epoch, optimizer)
        best_loss = save_best_model(model, best_model_path, epoch, avg_train_loss, best_loss, optimizer)
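
    # Illustrative usage sketch (not executed): how the best checkpoint produced
    # above could be reloaded for a quick inference pass. 'sample.jpg' is a
    # placeholder path, not part of this repo.
    #
    #   model = linenet_resnet50_fpn().to(device)
    #   model, _ = load_best_model(model, None, best_model_path, device)
    #   model.eval()
    #   img = transforms.ToTensor()(Image.open('sample.jpg').convert('RGB')).to(device)
    #   with torch.no_grad():
    #       pred = model([img])
    #   # pred[0]['boxes'] holds the detected boxes and pred[-1]['wires'] the line
    #   # predictions, i.e. the same structure consumed by show_line above.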