postprocess.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534
  1. import os
  2. import time
  3. import torch
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. from torchvision import transforms
  7. from models.wirenet.postprocess import postprocess
  8. from datetime import datetime
  9. def box_line(pred):
  10. '''
  11. :param pred: 预测结果
  12. :return:
  13. box与line一一对应
  14. {'box': [0.0, 34.23157501220703, 151.70858764648438, 125.10173797607422], 'line': array([[ 1.9720564, 81.73457 ],
  15. [ 1.9933801, 41.730167 ]], dtype=float32)}
  16. '''
  17. box_line = [[] for _ in range((len(pred) - 1))]
  18. for idx, box_ in enumerate(pred[0:-1]):
  19. box = box_['boxes'] # 是一个tensor
  20. line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
  21. score = pred[-1]['wires']['score'][idx]
  22. for i in box:
  23. aaa = {}
  24. aaa['box'] = i.tolist()
  25. aaa['line'] = []
  26. score_max = 0.0
  27. for j in range(len(line)):
  28. if (line[j][0][0] >= i[0] and line[j][1][0] >= i[0] and line[j][0][0] <= i[2] and
  29. line[j][1][0] <= i[2] and line[j][0][1] >= i[1] and line[j][1][1] >= i[1] and
  30. line[j][0][1] <= i[3] and line[j][1][1] <= i[3]):
  31. if score[j] > score_max:
  32. aaa['line'] = line[j]
  33. score_max = score[j]
  34. box_line[idx].append(aaa)
  35. def box_line_(pred):
  36. for idx, box_ in enumerate(pred[0:-1]):
  37. box = box_['boxes'] # 是一个tensor
  38. line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512
  39. score = pred[-1]['wires']['score'][idx]
  40. line_ = []
  41. score_ = []
  42. for i in box:
  43. score_max = 0.0
  44. tmp = [[0.0, 0.0], [0.0, 0.0]]
  45. for j in range(len(line)):
  46. if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  47. line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  48. line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  49. line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  50. if score[j] > score_max:
  51. tmp = line[j]
  52. score_max = score[j]
  53. line_.append(tmp)
  54. score_.append(score_max)
  55. processed_list = torch.tensor(line_)
  56. pred[idx]['line'] = processed_list
  57. processed_s_list = torch.tensor(score_)
  58. pred[idx]['line_score'] = processed_s_list
  59. return pred
  60. # box与line匹配后画在一张图上,不设置阈值,直接画
  61. def show_(imgs, pred, epoch, writer):
  62. col = [
  63. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  64. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  65. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  66. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  67. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  68. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  69. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  70. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  71. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  72. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  73. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  74. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  75. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  76. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  77. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  78. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  79. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  80. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  81. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  82. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  83. ]
  84. # print(len(col))
  85. im = imgs[0].permute(1, 2, 0)
  86. boxes = pred[0]['boxes'].cpu().numpy()
  87. line = pred[0]['line'].cpu().numpy()
  88. # 可视化预测结
  89. fig, ax = plt.subplots(figsize=(10, 10))
  90. ax.imshow(np.array(im))
  91. for idx, box in enumerate(boxes):
  92. x0, y0, x1, y1 = box
  93. ax.add_patch(
  94. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  95. for idx, (a, b) in enumerate(line):
  96. ax.scatter(a[1], a[0], c='#871F78', s=2)
  97. ax.scatter(b[1], b[0], c='#871F78', s=2)
  98. ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  99. # 将Matplotlib图像转换为Tensor
  100. fig.canvas.draw()
  101. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
  102. fig.canvas.get_width_height()[::-1] + (3,))
  103. plt.close()
  104. img2 = transforms.ToTensor()(image_from_plot)
  105. writer.add_image("all", img2, epoch)
  106. # box与line匹配后画在一张图上,设置阈值
  107. def show_predict(imgs, pred, t_start):
  108. col = [
  109. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  110. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  111. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  112. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  113. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  114. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  115. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  116. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  117. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  118. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  119. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  120. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  121. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  122. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  123. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  124. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  125. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  126. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  127. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  128. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  129. ]
  130. print(len(col))
  131. im = imgs.permute(1, 2, 0) # 处理为 [512, 512, 3]
  132. boxes = pred[0]['boxes'].cpu().numpy()
  133. box_scores = pred[0]['scores'].cpu().numpy()
  134. lines = pred[0]['line'].cpu().numpy()
  135. line_scores = pred[0]['line_score'].cpu().numpy()
  136. # 可视化预测结
  137. fig, ax = plt.subplots(figsize=(10, 10))
  138. ax.imshow(np.array(im))
  139. idx = 0
  140. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  141. for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
  142. x0, y0, x1, y1 = box
  143. # 框中无线的跳过
  144. if np.array_equal(line, tmp):
  145. continue
  146. a, b = line
  147. if box_score >= 0.7 or line_score >= 0.9:
  148. ax.add_patch(
  149. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  150. ax.scatter(a[1], a[0], c='#871F78', s=10)
  151. ax.scatter(b[1], b[0], c='#871F78', s=10)
  152. ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  153. idx = idx + 1
  154. t_end = time.time()
  155. print(f'predict used:{t_end - t_start}')
  156. plt.show()
  157. # 下面的都没有进行box与line的一一匹配
  158. # 只画线,设阈值
  159. def show_line(imgs, pred, t_start):
  160. im = imgs.permute(1, 2, 0)
  161. line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  162. # print(pred[-1]['wires']['score'])
  163. line_score = pred[-1]['wires']['score'].cpu().numpy()[0]
  164. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  165. line, line_score = postprocess(line, line_score, diag * 0.01, 0, False)
  166. print(f'lines num:{len(line)}')
  167. #
  168. # count = np.sum(line_score > 0.9)
  169. # print(f'draw line number:{count}')
  170. # 可视化预测结
  171. fig, ax = plt.subplots(figsize=(10, 10))
  172. ax.imshow(np.array(im))
  173. for idx, (a, b) in enumerate(line):
  174. # if line_score[idx] < 0.7:
  175. # continue
  176. ax.scatter(a[1], a[0], c='#871F78', s=2)
  177. ax.scatter(b[1], b[0], c='#871F78', s=2)
  178. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  179. t_end = time.time()
  180. print(f'show_line used:{t_end - t_start}')
  181. plt.show()
  182. # show_line优化
  183. def show_line_optimized(imgs, pred, t_start):
  184. im = imgs.permute(1, 2, 0).cpu().numpy()
  185. line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  186. line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  187. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  188. nlines, nscores = postprocess(line_data, line_scores, diag * 0.01, 0, False)
  189. fig, ax = plt.subplots(figsize=(10, 10))
  190. ax.imshow(im)
  191. for i, t in enumerate([0.9]):
  192. for (a, b), s in zip(nlines, nscores):
  193. if s < t:
  194. continue
  195. ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
  196. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  197. t_end = time.time()
  198. print(f'show_line_optimized used:{t_end - t_start}')
  199. plt.show()
  200. # 只画框,设阈值
  201. def show_box(imgs, pred, t_start):
  202. col = [
  203. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  204. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  205. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  206. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  207. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  208. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  209. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  210. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  211. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  212. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  213. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  214. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  215. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  216. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  217. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  218. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  219. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  220. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  221. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  222. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  223. ]
  224. # print(len(col))
  225. im = imgs.permute(1, 2, 0)
  226. boxes = pred[0]['boxes'].cpu().numpy()
  227. box_scores = pred[0]['scores'].cpu().numpy()
  228. # 可视化预测结
  229. fig, ax = plt.subplots(figsize=(10, 10))
  230. ax.imshow(np.array(im))
  231. for idx, box in enumerate(boxes):
  232. if box_scores[idx] < 0.7:
  233. continue
  234. x0, y0, x1, y1 = box
  235. ax.add_patch(
  236. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  237. t_end = time.time()
  238. print(f'show_box used:{t_end - t_start}')
  239. plt.show()
  240. # 将show_line与show_box合并,传入参数确定显示框还是线 都不显示,输出原图
  241. # def show_box_or_line(imgs, pred, show_line=False, show_box=False):
  242. # col = [
  243. # '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  244. # '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  245. # '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  246. # '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  247. # '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  248. # '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  249. # '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  250. # '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  251. # '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  252. # '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  253. # '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  254. # '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  255. # '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  256. # '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  257. # '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  258. # '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  259. # '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  260. # '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  261. # '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  262. # '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  263. # ]
  264. # # print(len(col))
  265. # im = imgs.permute(1, 2, 0)
  266. # boxes = pred[0]['boxes'].cpu().numpy()
  267. # line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  268. #
  269. # # 可视化预测结
  270. # fig, ax = plt.subplots(figsize=(10, 10))
  271. # ax.imshow(np.array(im))
  272. #
  273. # if show_box:
  274. # for idx, box in enumerate(boxes):
  275. # x0, y0, x1, y1 = box
  276. # ax.add_patch(
  277. # plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  278. #
  279. # if show_line:
  280. # for idx, (a, b) in enumerate(line):
  281. # ax.scatter(a[1], a[0], c='#871F78', s=2)
  282. # ax.scatter(b[1], b[0], c='#871F78', s=2)
  283. # ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  284. #
  285. # plt.show()
  286. # 将show_line与show_box合并,传入参数确定显示框还是线 一起画
  287. def show_box_and_line(imgs, pred, show_line=False, show_box=False):
  288. col = [
  289. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  290. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  291. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  292. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  293. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  294. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  295. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  296. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  297. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  298. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  299. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  300. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  301. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  302. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  303. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  304. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  305. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  306. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  307. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  308. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  309. ]
  310. # print(len(col))
  311. im = imgs.permute(1, 2, 0)
  312. boxes = pred[0]['boxes'].cpu().numpy()
  313. line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  314. # 可视化预测结
  315. fig, axs = plt.subplots(1, 2, figsize=(10, 10))
  316. if show_box:
  317. axs[0].imshow(np.array(im))
  318. for idx, box in enumerate(boxes):
  319. x0, y0, x1, y1 = box
  320. axs[0].add_patch(
  321. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  322. axs[0].set_title('Boxes')
  323. if show_line:
  324. axs[1].imshow(np.array(im))
  325. for idx, (a, b) in enumerate(line):
  326. axs[1].scatter(a[1], a[0], c='#871F78', s=2)
  327. axs[1].scatter(b[1], b[0], c='#871F78', s=2)
  328. axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  329. axs[1].set_title('Lines')
  330. # 调整子图之间的距离,防止标题和标签重叠
  331. plt.tight_layout()
  332. plt.show()
  333. def set_thresholds(threshold):
  334. if isinstance(threshold, list):
  335. if len(threshold) != 2:
  336. raise ValueError("Threshold list must contain exactly two elements.")
  337. a, b = threshold
  338. elif isinstance(threshold, (int, float)):
  339. a = b = threshold
  340. else:
  341. raise TypeError("Threshold must be either a list of two numbers or a single number.")
  342. return a, b
  343. def color():
  344. return [
  345. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  346. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  347. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  348. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  349. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  350. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  351. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  352. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  353. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  354. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  355. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  356. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  357. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  358. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  359. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  360. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  361. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  362. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  363. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  364. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  365. ]
  366. def show_all(imgs, pred, threshold, save_path, show):
  367. col = color()
  368. box_th, line_th = set_thresholds(threshold)
  369. im = imgs.permute(1, 2, 0)
  370. boxes = pred[0]['boxes'].cpu().numpy()
  371. box_scores = pred[0]['scores'].cpu().numpy()
  372. line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  373. line_score = pred[-1]['wires']['score'].cpu().numpy()[0]
  374. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  375. line, line_score = postprocess(line, line_score, diag * 0.01, 0, False)
  376. fig, axs = plt.subplots(1, 3, figsize=(10, 10))
  377. axs[0].imshow(np.array(im))
  378. for idx, box in enumerate(boxes):
  379. if box_scores[idx] < box_th:
  380. continue
  381. x0, y0, x1, y1 = box
  382. axs[0].add_patch(
  383. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  384. axs[0].set_title('Boxes')
  385. axs[1].imshow(np.array(im))
  386. for idx, (a, b) in enumerate(line):
  387. if line_score[idx] < line_th:
  388. continue
  389. axs[1].scatter(a[1], a[0], c='#871F78', s=2)
  390. axs[1].scatter(b[1], b[0], c='#871F78', s=2)
  391. axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  392. axs[1].set_title('Lines')
  393. axs[2].imshow(np.array(im))
  394. lines = pred[0]['line'].cpu().numpy()
  395. line_scores = pred[0]['line_score'].cpu().numpy()
  396. idx = 0
  397. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  398. for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
  399. x0, y0, x1, y1 = box
  400. # 框中无线的跳过
  401. if np.array_equal(line, tmp):
  402. continue
  403. a, b = line
  404. if box_score >= 0.7 or line_score >= 0.9:
  405. axs[2].add_patch(
  406. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  407. axs[2].scatter(a[1], a[0], c='#871F78', s=10)
  408. axs[2].scatter(b[1], b[0], c='#871F78', s=10)
  409. axs[2].plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  410. idx = idx + 1
  411. axs[2].set_title('Boxes and Lines')
  412. if save_path:
  413. save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box_line.png')
  414. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  415. plt.savefig(save_path)
  416. print(f"Saved result image to {save_path}")
  417. if show:
  418. # 调整子图之间的距离,防止标题和标签重叠
  419. plt.tight_layout()
  420. plt.show()
  421. def show_box_or_line(imgs, pred, threshold, save_path = None, show_line=False, show_box=False):
  422. col = color()
  423. box_th, line_th = set_thresholds(threshold)
  424. im = imgs.permute(1, 2, 0)
  425. boxes = pred[0]['boxes'].cpu().numpy()
  426. box_scores = pred[0]['scores'].cpu().numpy()
  427. line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  428. line_score = pred[-1]['wires']['score'].cpu().numpy()[0]
  429. # 可视化预测结
  430. fig, ax = plt.subplots(figsize=(10, 10))
  431. ax.imshow(np.array(im))
  432. if show_box:
  433. for idx, box in enumerate(boxes):
  434. if box_scores[idx] < box_th:
  435. continue
  436. x0, y0, x1, y1 = box
  437. ax.add_patch(
  438. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  439. if save_path:
  440. save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box.png')
  441. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  442. plt.savefig(save_path)
  443. print(f"Saved result image to {save_path}")
  444. if show_line:
  445. for idx, (a, b) in enumerate(line):
  446. if line_score[idx] < line_th:
  447. continue
  448. ax.scatter(a[1], a[0], c='#871F78', s=2)
  449. ax.scatter(b[1], b[0], c='#871F78', s=2)
  450. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  451. if save_path:
  452. save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'line.png')
  453. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  454. plt.savefig(save_path)
  455. print(f"Saved result image to {save_path}")
  456. plt.show()