# train——line_rcnn.py
# 2025/2/9
import os
from collections import OrderedDict
from typing import Optional, Any

import cv2
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
from skimage import io
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.utils import draw_bounding_boxes

from models.config.config_tool import read_yaml
from models.line_detect.dataset_LD import WirePointDataset
from models.line_detect.line_net import linenet_resnet50_fpn
from models.wirenet.postprocess import postprocess
from tools import utils
  19. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  20. def _loss(losses):
  21. total_loss = 0
  22. for i in losses.keys():
  23. if i != "loss_wirepoint":
  24. total_loss += losses[i]
  25. else:
  26. loss_labels = losses[i]["losses"]
  27. loss_labels_k = list(loss_labels[0].keys())
  28. for j, name in enumerate(loss_labels_k):
  29. loss = loss_labels[0][name].mean()
  30. total_loss += loss
  31. return total_loss
  32. cmap = plt.get_cmap("jet")
  33. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  34. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  35. sm.set_array([])
  36. def c(x):
  37. return sm.to_rgba(x)
  38. def imshow(im):
  39. plt.close()
  40. plt.tight_layout()
  41. plt.imshow(im)
  42. plt.colorbar(sm, fraction=0.046)
  43. plt.xlim([0, im.shape[0]])
  44. plt.ylim([im.shape[0], 0])
  45. def show_line(img, pred, epoch, writer):
  46. im = img.permute(1, 2, 0)
  47. writer.add_image("ori", im, epoch, dataformats="HWC")
  48. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
  49. colors="yellow", width=1)
  50. writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
  51. PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
  52. H = pred[1]['wires']
  53. lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
  54. scores = H["score"][0].cpu().numpy()
  55. for i in range(1, len(lines)):
  56. if (lines[i] == lines[0]).all():
  57. lines = lines[:i]
  58. scores = scores[:i]
  59. break
  60. # postprocess lines to remove overlapped lines
  61. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  62. nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
  63. for i, t in enumerate([0.8]):
  64. plt.gca().set_axis_off()
  65. plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
  66. plt.margins(0, 0)
  67. for (a, b), s in zip(nlines, nscores):
  68. if s < t:
  69. continue
  70. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
  71. plt.scatter(a[1], a[0], **PLTOPTS)
  72. plt.scatter(b[1], b[0], **PLTOPTS)
  73. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  74. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  75. plt.imshow(im)
  76. plt.tight_layout()
  77. fig = plt.gcf()
  78. fig.canvas.draw()
  79. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
  80. fig.canvas.get_width_height()[::-1] + (3,))
  81. plt.close()
  82. img2 = transforms.ToTensor()(image_from_plot)
  83. writer.add_image("output", img2, epoch)
  84. if __name__ == '__main__':
  85. cfg = r'./config/wireframe.yaml'
  86. cfg = read_yaml(cfg)
  87. print(f'cfg:{cfg}')
  88. print(cfg['model']['n_dyn_negl'])
  89. # net = WirepointPredictor()
  90. dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
  91. train_sampler = torch.utils.data.RandomSampler(dataset_train)
  92. # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
  93. train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
  94. train_collate_fn = utils.collate_fn_wirepoint
  95. data_loader_train = torch.utils.data.DataLoader(
  96. dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
  97. )
  98. dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
  99. val_sampler = torch.utils.data.RandomSampler(dataset_val)
  100. # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
  101. val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
  102. val_collate_fn = utils.collate_fn_wirepoint
  103. data_loader_val = torch.utils.data.DataLoader(
  104. dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
  105. )
  106. model = linenet_resnet50_fpn().to(device)
  107. optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
  108. writer = SummaryWriter(cfg['io']['logdir'])
  109. def move_to_device(data, device):
  110. if isinstance(data, (list, tuple)):
  111. return type(data)(move_to_device(item, device) for item in data)
  112. elif isinstance(data, dict):
  113. return {key: move_to_device(value, device) for key, value in data.items()}
  114. elif isinstance(data, torch.Tensor):
  115. return data.to(device)
  116. else:
  117. return data # 对于非张量类型的数据不做任何改变
  118. # def writer_loss(writer, losses, epoch):
  119. # try:
  120. # for key, value in losses.items():
  121. # if key == 'loss_wirepoint':
  122. # for subdict in losses['loss_wirepoint']['losses']:
  123. # for subkey, subvalue in subdict.items():
  124. # writer.add_scalar(f'loss_wirepoint/{subkey}',
  125. # subvalue.item() if hasattr(subvalue, 'item') else subvalue,
  126. # epoch)
  127. # elif isinstance(value, torch.Tensor):
  128. # writer.add_scalar(key, value.item(), epoch)
  129. # except Exception as e:
  130. # print(f"TensorBoard logging error: {e}")
  131. def writer_loss(writer, losses, epoch):
  132. try:
  133. for key, value in losses.items():
  134. if key == 'loss_wirepoint':
  135. for subdict in losses['loss_wirepoint']['losses']:
  136. for subkey, subvalue in subdict.items():
  137. writer.add_scalar(f'loss/{subkey}',
  138. subvalue.item() if hasattr(subvalue, 'item') else subvalue,
  139. epoch)
  140. elif isinstance(value, torch.Tensor):
  141. writer.add_scalar(f'loss/{key}', value.item(), epoch)
  142. except Exception as e:
  143. print(f"TensorBoard logging error: {e}")
  144. for epoch in range(cfg['optim']['max_epoch']):
  145. print(f"epoch:{epoch}")
  146. model.train()
  147. for imgs, targets in data_loader_train:
  148. losses = model(move_to_device(imgs, device), move_to_device(targets, device))
  149. # print(losses)
  150. loss = _loss(losses)
  151. optimizer.zero_grad()
  152. loss.backward()
  153. optimizer.step()
  154. writer_loss(writer, losses, epoch)
  155. model.eval()
  156. with torch.no_grad():
  157. for batch_idx, (imgs, targets) in enumerate(data_loader_val):
  158. pred = model(move_to_device(imgs, device))
  159. if batch_idx == 0:
  160. show_line(imgs[0], pred, epoch, writer)
  161. break