123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199 |
- # 2025/2/9
- import os
- from typing import Optional, Any
- import cv2
- import numpy as np
- import torch
- from models.config.config_tool import read_yaml
- from models.line_detect.dataset_LD import WirePointDataset
- from tools import utils
- from torch.utils.tensorboard import SummaryWriter
- import matplotlib.pyplot as plt
- import matplotlib as mpl
- from skimage import io
- from models.line_detect.line_net import linenet_resnet50_fpn
- from torchvision.utils import draw_bounding_boxes
- from models.wirenet.postprocess import postprocess
- from torchvision import transforms
- from collections import OrderedDict
# Device used for the model and every training/validation batch: the default
# CUDA device when CUDA is available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
- def _loss(losses):
- total_loss = 0
- for i in losses.keys():
- if i != "loss_wirepoint":
- total_loss += losses[i]
- else:
- loss_labels = losses[i]["losses"]
- loss_labels_k = list(loss_labels[0].keys())
- for j, name in enumerate(loss_labels_k):
- loss = loss_labels[0][name].mean()
- total_loss += loss
- return total_loss
# Shared score -> colour mapping: scores are normalised with vmin=0.4 and
# vmax=1.0 and looked up in the "jet" colormap.
cmap = plt.get_cmap("jet")
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
# ``sm`` only supplies the colormap/norm (for ``c`` and the colour bar);
# it never draws data itself, so it is given an empty data array.
sm.set_array([])


def c(x):
    """Return the RGBA colour that ``sm`` assigns to the score(s) ``x``."""
    rgba = sm.to_rgba(x)
    return rgba
def imshow(im):
    """Show ``im`` in a new pyplot figure with the score colour bar.

    The current figure is closed first. The axes are fixed to the image
    extent, with y increasing downwards to match row/column image coordinates.

    Args:
        im: image array of shape (rows, cols[, channels]).
    """
    plt.close()
    plt.tight_layout()
    plt.imshow(im)
    # ``sm`` is not attached to any Axes, so name the Axes the colour bar
    # should take space from. Recent Matplotlib versions no longer guess this.
    plt.colorbar(sm, ax=plt.gca(), fraction=0.046)
    # x spans the image width (shape[1]); y spans the height (shape[0]).
    plt.xlim([0, im.shape[1]])
    plt.ylim([im.shape[0], 0])
def _trim_padded_lines(lines, scores):
    """Cut ``lines``/``scores`` just before the first exact repeat of ``lines[0]``.

    The fixed-size line output appears to be padded by repeating earlier
    lines (inferred from this check; TODO confirm against the line head),
    so everything from the first duplicate of the top line onward is dropped.
    """
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            return lines[:i], scores[:i]
    return lines, scores


def _figure_to_rgb_array(fig):
    """Render ``fig`` and return a writable (H, W, 3) uint8 copy of its pixels."""
    fig.canvas.draw()
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop alpha, and copy so the result outlives the figure's own buffer.
    return rgba[..., :3].copy()


def show_line(img, pred, epoch, writer, *, score_threshold=0.8):
    """Log one validation image and the model's predictions to TensorBoard.

    Three images are written at step ``epoch``:

    * ``"ori"``: the input image;
    * ``"boxes"``: the input with ``pred[0]["boxes"]`` drawn in yellow;
    * ``"output"``: the input with the predicted lines drawn on top. Only
      lines kept by ``postprocess`` (overlap removal) and scoring at least
      ``score_threshold`` are drawn. The image is rendered with matplotlib.

    Args:
        img: CHW image tensor. It must live on the CPU because it is passed
            to matplotlib. Presumably float in [0, 1], since it is scaled by
            255 before drawing boxes (TODO confirm).
        pred: model output. ``pred[0]["boxes"]`` holds the boxes and
            ``pred[1]["wires"]`` holds ``"lines"``/``"score"`` tensors whose
            first entry presumably belongs to ``img``.
        epoch: TensorBoard step used for all three images.
        writer: a ``SummaryWriter`` (anything with ``add_image``).
        score_threshold: minimum score a line needs in order to be drawn.
            Defaults to 0.8, the value previously hard-coded.
    """
    im = img.permute(1, 2, 0)
    writer.add_image("ori", im, epoch, dataformats="HWC")

    boxed_image = draw_bounding_boxes(
        (img * 255).to(torch.uint8), pred[0]["boxes"], colors="yellow", width=1
    )
    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")

    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
    H = pred[1]["wires"]
    # Endpoints are (row, col) pairs; the /128 suggests a 128-unit prediction
    # grid (TODO confirm). Rows are scaled by the height, cols by the width.
    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
    scores = H["score"][0].cpu().numpy()
    lines, scores = _trim_padded_lines(lines, scores)

    # postprocess lines to remove overlapped lines
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)

    fig = plt.gcf()
    try:
        plt.gca().set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        for (a, b), s in zip(nlines, nscores):
            if s < score_threshold:
                continue
            # Matplotlib takes (x, y) = (col, row).
            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
            plt.scatter(a[1], a[0], **PLTOPTS)
            plt.scatter(b[1], b[0], **PLTOPTS)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(im)
        plt.tight_layout()
        image_from_plot = _figure_to_rgb_array(fig)
    finally:
        # Never leave a half-drawn figure in pyplot's global state.
        plt.close(fig)

    img2 = transforms.ToTensor()(image_from_plot)
    writer.add_image("output", img2, epoch)
def move_to_device(data, device):
    """Recursively move every tensor inside ``data`` to ``device``.

    Lists/tuples are rebuilt with their original type and dicts are rebuilt,
    in both cases with their items moved. Tensors are moved with ``.to``.
    Any other value is returned unchanged.
    """
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_device(item, device) for item in data)
    elif isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    elif isinstance(data, torch.Tensor):
        return data.to(device)
    else:
        return data  # non-tensor data is returned unchanged


def _loss_scalar(value):
    """Reduce one logged loss term to a Python number for ``add_scalar``."""
    if isinstance(value, torch.Tensor):
        # Wirepoint terms may hold one value per sample (``_loss`` averages
        # them too), and ``.item()`` only accepts one-element tensors.
        return value.item() if value.numel() == 1 else value.detach().mean().item()
    if hasattr(value, "item"):
        return value.item()
    return value


def writer_loss(writer, losses, step):
    """Log every loss term in ``losses`` to TensorBoard as ``loss/<name>`` at ``step``.

    Each term of each dict in ``losses["loss_wirepoint"]["losses"]`` is logged
    under its own name. Other tensor entries are logged under their key.
    Logging is best-effort: any error is printed and the remaining terms of
    this call are skipped, so TensorBoard problems never stop training.
    """
    try:
        for key, value in losses.items():
            if key == 'loss_wirepoint':
                for subdict in value['losses']:
                    for subkey, subvalue in subdict.items():
                        writer.add_scalar(f'loss/{subkey}', _loss_scalar(subvalue), step)
            elif isinstance(value, torch.Tensor):
                writer.add_scalar(f'loss/{key}', value.item(), step)
    except Exception as e:
        print(f"TensorBoard logging error: {e}")


def _build_loader(datadir, dataset_type, batch_size=2, num_workers=8):
    """Build a shuffled ``WirePointDataset`` loader that drops the last partial batch."""
    dataset = WirePointDataset(dataset_path=datadir, dataset_type=dataset_type)
    sampler = torch.utils.data.RandomSampler(dataset)
    batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size=batch_size, drop_last=True)
    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=utils.collate_fn_wirepoint,
    )


if __name__ == '__main__':
    cfg = read_yaml(r'./config/wireframe.yaml')
    print(f'cfg:{cfg}')
    print(cfg['model']['n_dyn_negl'])

    data_loader_train = _build_loader(cfg['io']['datadir'], 'train')
    data_loader_val = _build_loader(cfg['io']['datadir'], 'val')

    model = linenet_resnet50_fpn().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
    writer = SummaryWriter(cfg['io']['logdir'])

    # Scalars are logged once per optimisation step. Using the epoch as the
    # step would stack every batch of an epoch onto the same x value.
    global_step = 0
    try:
        for epoch in range(cfg['optim']['max_epoch']):
            print(f"epoch:{epoch}")
            model.train()
            for imgs, targets in data_loader_train:
                losses = model(move_to_device(imgs, device), move_to_device(targets, device))
                loss = _loss(losses)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                writer_loss(writer, losses, global_step)
                global_step += 1

            model.eval()
            with torch.no_grad():
                # Only the first validation batch is visualised each epoch.
                for imgs, _ in data_loader_val:
                    pred = model(move_to_device(imgs, device))
                    show_line(imgs[0], pred, epoch, writer)
                    break
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()
|