train——line_rcnn.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. # 2025/2/9
  2. import os
  3. from typing import Optional, Any
  4. import cv2
  5. import numpy as np
  6. import torch
  7. from models.config.config_tool import read_yaml
  8. from models.line_detect.dataset_LD import WirePointDataset
  9. from tools import utils
  10. from torch.utils.tensorboard import SummaryWriter
  11. import matplotlib.pyplot as plt
  12. import matplotlib as mpl
  13. from skimage import io
  14. from models.line_detect.line_net import linenet_resnet50_fpn
  15. from torchvision.utils import draw_bounding_boxes
  16. from models.wirenet.postprocess import postprocess
  17. from torchvision import transforms
  18. from collections import OrderedDict
  19. from PIL import Image
  20. from predict import box_line_, show_
  21. import matplotlib.pyplot as plt
  22. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  23. def _loss(losses):
  24. total_loss = 0
  25. for i in losses.keys():
  26. if i != "loss_wirepoint":
  27. total_loss += losses[i]
  28. else:
  29. loss_labels = losses[i]["losses"]
  30. loss_labels_k = list(loss_labels[0].keys())
  31. for j, name in enumerate(loss_labels_k):
  32. loss = loss_labels[0][name].mean()
  33. total_loss += loss
  34. return total_loss
  35. cmap = plt.get_cmap("jet")
  36. norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0)
  37. sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
  38. sm.set_array([])
  39. def c(x):
  40. return sm.to_rgba(x)
  41. def imshow(im):
  42. plt.close()
  43. plt.tight_layout()
  44. plt.imshow(im)
  45. plt.colorbar(sm, fraction=0.046)
  46. plt.xlim([0, im.shape[0]])
  47. plt.ylim([im.shape[0], 0])
  48. def show_line(img, pred, epoch, writer):
  49. im = img.permute(1, 2, 0)
  50. writer.add_image("ori", im, epoch, dataformats="HWC")
  51. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
  52. colors="yellow", width=1)
  53. writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
  54. PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
  55. # print(f'pred[1]:{pred[1]}')
  56. H = pred[-1]['wires']
  57. lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
  58. scores = H["score"][0].cpu().numpy()
  59. for i in range(1, len(lines)):
  60. if (lines[i] == lines[0]).all():
  61. lines = lines[:i]
  62. scores = scores[:i]
  63. break
  64. # postprocess lines to remove overlapped lines
  65. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  66. nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
  67. for i, t in enumerate([0.85]):
  68. plt.gca().set_axis_off()
  69. plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
  70. plt.margins(0, 0)
  71. for (a, b), s in zip(nlines, nscores):
  72. if s < t:
  73. continue
  74. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
  75. plt.scatter(a[1], a[0], **PLTOPTS)
  76. plt.scatter(b[1], b[0], **PLTOPTS)
  77. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  78. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  79. plt.imshow(im)
  80. plt.tight_layout()
  81. fig = plt.gcf()
  82. fig.canvas.draw()
  83. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
  84. fig.canvas.get_width_height()[::-1] + (3,))
  85. plt.close()
  86. img2 = transforms.ToTensor()(image_from_plot)
  87. writer.add_image("output", img2, epoch)
  88. def save_best_model(model, save_path, epoch, current_loss, best_loss, optimizer=None):
  89. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  90. if current_loss < best_loss:
  91. checkpoint = {
  92. 'epoch': epoch,
  93. 'model_state_dict': model.state_dict(),
  94. 'loss': current_loss
  95. }
  96. if optimizer is not None:
  97. checkpoint['optimizer_state_dict'] = optimizer.state_dict()
  98. torch.save(checkpoint, save_path)
  99. print(f"Saved best model at epoch {epoch} with loss {current_loss:.4f}")
  100. return current_loss
  101. return best_loss
  102. def save_latest_model(model, save_path, epoch, optimizer=None):
  103. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  104. checkpoint = {
  105. 'epoch': epoch,
  106. 'model_state_dict': model.state_dict(),
  107. }
  108. if optimizer is not None:
  109. checkpoint['optimizer_state_dict'] = optimizer.state_dict()
  110. torch.save(checkpoint, save_path)
  111. def load_best_model(model, optimizer, save_path, device):
  112. if os.path.exists(save_path):
  113. checkpoint = torch.load(save_path, map_location=device)
  114. model.load_state_dict(checkpoint['model_state_dict'])
  115. if optimizer is not None:
  116. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  117. epoch = checkpoint['epoch']
  118. loss = checkpoint['loss']
  119. print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}")
  120. else:
  121. print(f"No saved model found at {save_path}")
  122. return model, optimizer
  123. def predict(self, img, show_boxes=True, show_keypoint=True, show_line=True, save=False, save_path=None):
  124. self.load_weight('weights/best.pt')
  125. self.__model.eval()
  126. if isinstance(img, str):
  127. img = Image.open(img).convert("RGB")
  128. # 预处理图像
  129. img_tensor = self.transforms(img)
  130. with torch.no_grad():
  131. predictions = self.__model([img_tensor])
  132. # 后处理预测结果
  133. boxes = predictions[0]['boxes'].cpu().numpy()
  134. keypoints = predictions[0]['keypoints'].cpu().numpy()
  135. # 可视化预测结果
  136. if show_boxes or show_keypoint or show_line or save:
  137. fig, ax = plt.subplots(figsize=(10, 10))
  138. ax.imshow(np.array(img))
  139. if show_boxes:
  140. for box in boxes:
  141. x0, y0, x1, y1 = box
  142. ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor='yellow', linewidth=1))
  143. for (a, b) in keypoints:
  144. if show_keypoint:
  145. ax.scatter(a[0], a[1], c='c', s=2)
  146. ax.scatter(b[0], b[1], c='c', s=2)
  147. if show_line:
  148. ax.plot([a[0], b[0]], [a[1], b[1]], c='red', linewidth=1)
  149. if show_boxes or show_keypoint or show_line:
  150. plt.show()
  151. if save:
  152. fig.savefig(save_path)
  153. print(f"Prediction saved to {save_path}")
  154. plt.close(fig)
  155. if __name__ == '__main__':
  156. cfg = r'./config/wireframe.yaml'
  157. cfg = read_yaml(cfg)
  158. print(f'cfg:{cfg}')
  159. print(cfg['model']['n_dyn_negl'])
  160. # net = WirepointPredictor()
  161. dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train')
  162. train_sampler = torch.utils.data.RandomSampler(dataset_train)
  163. # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
  164. train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, batch_size=2, drop_last=True)
  165. train_collate_fn = utils.collate_fn_wirepoint
  166. data_loader_train = torch.utils.data.DataLoader(
  167. dataset_train, batch_sampler=train_batch_sampler, num_workers=8, collate_fn=train_collate_fn
  168. )
  169. dataset_val = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='val')
  170. val_sampler = torch.utils.data.RandomSampler(dataset_val)
  171. # test_sampler = torch.utils.data.SequentialSampler(dataset_test)
  172. val_batch_sampler = torch.utils.data.BatchSampler(val_sampler, batch_size=2, drop_last=True)
  173. val_collate_fn = utils.collate_fn_wirepoint
  174. data_loader_val = torch.utils.data.DataLoader(
  175. dataset_val, batch_sampler=val_batch_sampler, num_workers=8, collate_fn=val_collate_fn
  176. )
  177. model = linenet_resnet50_fpn().to(device)
  178. optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
  179. writer = SummaryWriter(cfg['io']['logdir'])
  180. # 加载权重
  181. save_path = 'logs/pth/best_model.pth'
  182. model, optimizer = load_best_model(model, optimizer, save_path, device)
  183. logdir_with_pth = os.path.join(cfg['io']['logdir'], 'pth')
  184. os.makedirs(logdir_with_pth, exist_ok=True) # 创建目录(如果不存在)
  185. latest_model_path = os.path.join(logdir_with_pth, 'latest_model.pth')
  186. best_model_path = os.path.join(logdir_with_pth, 'best_model.pth')
  187. global_step = 0
  188. def move_to_device(data, device):
  189. if isinstance(data, (list, tuple)):
  190. return type(data)(move_to_device(item, device) for item in data)
  191. elif isinstance(data, dict):
  192. return {key: move_to_device(value, device) for key, value in data.items()}
  193. elif isinstance(data, torch.Tensor):
  194. return data.to(device)
  195. else:
  196. return data # 对于非张量类型的数据不做任何改变
  197. def writer_loss(writer, losses, epoch):
  198. try:
  199. for key, value in losses.items():
  200. if key == 'loss_wirepoint':
  201. for subdict in losses['loss_wirepoint']['losses']:
  202. for subkey, subvalue in subdict.items():
  203. writer.add_scalar(f'loss/{subkey}',
  204. subvalue.item() if hasattr(subvalue, 'item') else subvalue,
  205. epoch)
  206. elif isinstance(value, torch.Tensor):
  207. writer.add_scalar(f'loss/{key}', value.item(), epoch)
  208. except Exception as e:
  209. print(f"TensorBoard logging error: {e}")
  210. for epoch in range(cfg['optim']['max_epoch']):
  211. print(f"epoch:{epoch}")
  212. model.train()
  213. total_train_loss = 0.0
  214. for imgs, targets in data_loader_train:
  215. losses = model(move_to_device(imgs, device), move_to_device(targets, device))
  216. # print(losses)
  217. loss = _loss(losses)
  218. total_train_loss += loss.item()
  219. optimizer.zero_grad()
  220. loss.backward()
  221. optimizer.step()
  222. writer_loss(writer, losses, epoch)
  223. model.eval()
  224. with torch.no_grad():
  225. for batch_idx, (imgs, targets) in enumerate(data_loader_val):
  226. pred = model(move_to_device(imgs, device))
  227. pred_ = box_line_(pred) # 将box与line对应
  228. show_(imgs, pred_, epoch, writer)
  229. if batch_idx == 0:
  230. show_line(imgs[0], pred, epoch, writer)
  231. break
  232. avg_train_loss = total_train_loss / len(data_loader_train)
  233. writer.add_scalar('loss/train', avg_train_loss, epoch)
  234. best_loss = 10000
  235. save_latest_model(
  236. model,
  237. latest_model_path,
  238. epoch,
  239. optimizer
  240. )
  241. best_loss = save_best_model(
  242. model,
  243. best_model_path,
  244. epoch,
  245. avg_train_loss,
  246. best_loss,
  247. optimizer
  248. )