# train_line_rcnn.py — training script for the line-RCNN (wireframe) detector
  1. # 2025/2/9
  2. import os
  3. import numpy as np
  4. import torch
  5. from models.config.config_tool import read_yaml
  6. from models.line_detect.dataset_LD import WirePointDataset
  7. from tools import utils
  8. from torch.utils.tensorboard import SummaryWriter
  9. import matplotlib as mpl
  10. from models.line_detect.line_net import linenet_resnet50_fpn
  11. from torchvision.utils import draw_bounding_boxes
  12. from models.wirenet.postprocess import postprocess
  13. from torchvision import transforms
  14. from PIL import Image
  15. from models.line_detect.postprocess import box_line_, show_
  16. import matplotlib.pyplot as plt
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  18. def _loss(losses):
  19. total_loss = 0
  20. for i in losses.keys():
  21. if i != "loss_wirepoint":
  22. total_loss += losses[i]
  23. else:
  24. loss_labels = losses[i]["losses"]
  25. loss_labels_k = list(loss_labels[0].keys())
  26. for j, name in enumerate(loss_labels_k):
  27. loss = loss_labels[0][name].mean()
  28. total_loss += loss
  29. return total_loss
# Shared colormap machinery: line-confidence scores in [0.4, 1.0] are mapped
# through the "jet" colormap; `sm` is reused for the colorbar in imshow().
cmap = plt.get_cmap("jet")
norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
def c(x):
    # Map a scalar confidence score to an RGBA color via the shared mapper.
    return sm.to_rgba(x)
  36. def imshow(im):
  37. plt.close()
  38. plt.tight_layout()
  39. plt.imshow(im)
  40. plt.colorbar(sm, fraction=0.046)
  41. plt.xlim([0, im.shape[0]])
  42. plt.ylim([im.shape[0], 0])
def show_line(img, pred, epoch, writer):
    """Log one validation sample to TensorBoard: the raw image, the predicted
    boxes, and the predicted line segments rendered with matplotlib.

    img:    CHW float tensor (values scaled by 255 for box drawing below).
    pred:   model output; pred[0]["boxes"] holds detection boxes and
            pred[-1]["wires"] holds line endpoints and scores.
    epoch:  step value for the TensorBoard tags.
    writer: a SummaryWriter.
    """
    im = img.permute(1, 2, 0)
    writer.add_image("ori", im, epoch, dataformats="HWC")
    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
                                      colors="yellow", width=1)
    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
    # print(f'pred[1]:{pred[1]}')
    H = pred[-1]['wires']
    # Endpoints come out on a 128x128 grid — rescale to image pixels.
    # NOTE(review): assumes the wire head predicts on a 128x128 lattice; confirm.
    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
    scores = H["score"][0].cpu().numpy()
    # The head pads its output by repeating the first line; truncate at the
    # first duplicate so padded entries are not drawn.
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            lines = lines[:i]
            scores = scores[:i]
            break
    # postprocess lines to remove overlapped lines
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
    # One rendering pass per score threshold (currently only 0.85).
    for i, t in enumerate([0.85]):
        plt.gca().set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        for (a, b), s in zip(nlines, nscores):
            if s < t:
                continue
            # Endpoints are (row, col); matplotlib plots (x, y), hence the swap.
            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
            plt.scatter(a[1], a[0], **PLTOPTS)
            plt.scatter(b[1], b[0], **PLTOPTS)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(im)
        plt.tight_layout()
        fig = plt.gcf()
        fig.canvas.draw()
        width, height = fig.get_size_inches() * fig.get_dpi()  # figure size in pixels
        tmp_img = fig.canvas.tostring_argb()
        tmp_img_np = np.frombuffer(tmp_img, dtype=np.uint8)
        tmp_img_np = tmp_img_np.reshape(int(height), int(width), 4)
        img_rgb = tmp_img_np[:, :, 1:]  # drop the alpha channel (ARGB -> RGB)
        # image_from_plot = np.frombuffer(tmp_img[:,:,1:], dtype=np.uint8).reshape(
        #     fig.canvas.get_width_height()[::-1] + (3,))
        # plt.close()
        img2 = transforms.ToTensor()(img_rgb)
        writer.add_image("z-output", img2, epoch)
  88. def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
  89. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  90. if current_loss < best_loss:
  91. checkpoint = {
  92. 'epoch': epoch,
  93. 'model_state_dict': model.state_dict(),
  94. 'loss': current_loss
  95. }
  96. if optimizer is not None:
  97. checkpoint['optimizer_state_dict'] = optimizer.state_dict()
  98. torch.save(checkpoint, save_path)
  99. print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
  100. return current_loss
  101. return best_loss
  102. def save_latest_model(model, save_path, epoch, optimizer=None):
  103. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  104. checkpoint = {
  105. 'epoch': epoch,
  106. 'model_state_dict': model.state_dict(),
  107. }
  108. if optimizer is not None:
  109. checkpoint['optimizer_state_dict'] = optimizer.state_dict()
  110. torch.save(checkpoint, save_path)
  111. def load_best_model(model, optimizer, save_path, device):
  112. if os.path.exists(save_path):
  113. checkpoint = torch.load(save_path, map_location=device)
  114. model.load_state_dict(checkpoint['model_state_dict'])
  115. if optimizer is not None:
  116. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  117. epoch = checkpoint['epoch']
  118. loss = checkpoint['loss']
  119. print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  120. else:
  121. print(f"No saved model found at {save_path}")
  122. return model, optimizer
def predict(self, img, show_boxes=True, show_keypoint=True, show_line=True, save=False, save_path=None):
    """Run single-image inference and visualize boxes, keypoints, and lines.

    NOTE(review): written as an instance method (it uses self.load_weight,
    self.__model, and self.transforms) yet appears at module level in this
    file — presumably it belongs to a predictor class; confirm before use.

    img:       a filesystem path or a PIL image; paths are opened as RGB.
    save:      when True, the rendered figure is written to save_path.
    save_path: destination file for the saved figure (required if save=True).
    """
    self.load_weight('weights/best.pt')
    self.__model.eval()
    if isinstance(img, str):
        img = Image.open(img).convert("RGB")
    # Preprocess the image
    img_tensor = self.transforms(img)
    with torch.no_grad():
        predictions = self.__model([img_tensor])
    # Post-process the predictions
    boxes = predictions[0]['boxes'].cpu().numpy()
    keypoints = predictions[0]['keypoints'].cpu().numpy()
    # Visualize the predictions
    if show_boxes or show_keypoint or show_line or save:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(np.array(img))
        if show_boxes:
            for box in boxes:
                x0, y0, x1, y1 = box
                ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1))
        # Each keypoint entry is a pair of endpoints (a, b) of one line.
        for (a, b) in keypoints:
            if show_keypoint:
                ax.scatter(a[0], a[1], c='c', s=2)
                ax.scatter(b[0], b[1], c='c', s=2)
            if show_line:
                ax.plot([a[0], b[0]], [a[1], b[1]], c='red', linewidth=1)
        if show_boxes or show_keypoint or show_line:
            plt.show()
        if save:
            fig.savefig(save_path)
            print(f"Prediction saved to {save_path}")
        plt.close(fig)
  155. if __name__ == '__main__':
  156. cfg = r'./config/wireframe.yaml'
  157. cfg = read_yaml(cfg)
  158. print(f'cfg:{cfg}')
  159. print(cfg['model']['n_dyn_negl'])
  160. # net = WirepointPredictor()
  161. dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
  162. train_sampler = torch.utils.data.RandomSampler(dataset_train)
  163. # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
  164. train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
  165. train_collate_fn = utils.collate_fn_wirepoint
  166. data_loader_train = torch.utils.data.DataLoader(
  167. dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
  168. )
  169. dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
  170. val_sampler = torch.utils.data.RandomSampler(dataset_val)
  171. # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
  172. val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
  173. val_collate_fn = utils.collate_fn_wirepoint
  174. data_loader_val = torch.utils.data.DataLoader(
  175. dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
  176. )
  177. model = linenet_resnet50_fpn().to(device)
  178. optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
  179. writer = SummaryWriter(cfg['io']['logdir'])
  180. # 加载权重
  181. save_path = 'logs/pth/best_model.pth'
  182. model, optimizer = load_best_model(model, optimizer, save_path, device)
  183. logdir_with_pth = os.path.join(cfg['io']['logdir'], 'pth')
  184. os.makedirs(logdir_with_pth, exist_ok=True) # 创建目录(如果不存在)
  185. latest_model_path = os.path.join(logdir_with_pth, 'latest_model.pth')
  186. best_model_path = os.path.join(logdir_with_pth, 'best_model.pth')
  187. global_step = 0
  188. def move_to_device(data, device):
  189. if isinstance(data, (list, tuple)):
  190. return type(data)(move_to_device(item, device) for item in data)
  191. elif isinstance(data, dict):
  192. return {key: move_to_device(value, device) for key, value in data.items()}
  193. elif isinstance(data, torch.Tensor):
  194. return data.to(device)
  195. else:
  196. return data # 对于非张量类型的数据不做任何改变
  197. def writer_loss(writer, losses, epoch):
  198. try:
  199. for key, value in losses.items():
  200. if key == 'loss_wirepoint':
  201. for subdict in losses['loss_wirepoint']['losses']:
  202. for subkey, subvalue in subdict.items():
  203. writer.add_scalar(f'loss/{subkey}',
  204. subvalue.item() if hasattr(subvalue, 'item') else subvalue,
  205. epoch)
  206. elif isinstance(value, torch.Tensor):
  207. writer.add_scalar(f'loss/{key}', value.item(), epoch)
  208. except Exception as e:
  209. print(f"TensorBoard logging error: {e}")
  210. for epoch in range(cfg['optim']['max_epoch']):
  211. print(f"epoch:{epoch}")
  212. model.train()
  213. total_train_loss = 0.0
  214. for imgs, targets in data_loader_train:
  215. losses = model(move_to_device(imgs, device), move_to_device(targets, device))
  216. # print(losses)
  217. loss = _loss(losses)
  218. total_train_loss += loss.item()
  219. optimizer.zero_grad()
  220. loss.backward()
  221. optimizer.step()
  222. writer_loss(writer, losses, epoch)
  223. model.eval()
  224. with torch.no_grad():
  225. for batch_idx, (imgs, targets) in enumerate(data_loader_val):
  226. pred = model(move_to_device(imgs, device))
  227. # pred_ = box_line_(pred) # 将box与line对应
  228. # show_(imgs, pred_, epoch, writer)
  229. if batch_idx == 0:
  230. show_line(imgs[0], pred, epoch, writer)
  231. break
  232. avg_train_loss = total_train_loss / len(data_loader_train)
  233. writer.add_scalar('loss/train', avg_train_loss, epoch)
  234. best_loss = 10000
  235. save_latest_model(
  236. model,
  237. latest_model_path,
  238. epoch,
  239. optimizer
  240. )
  241. best_loss = save_best_model(
  242. model,
  243. best_model_path,
  244. epoch,
  245. avg_train_loss,
  246. best_loss,
  247. optimizer
  248. )