| 
														
															@@ -1,3 +1,4 @@ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+import os 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 import time 
														 | 
														
														 | 
														
															 import time 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 import torch 
														 | 
														
														 | 
														
															 import torch 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -7,6 +8,8 @@ from torchvision import transforms 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 from models.wirenet.postprocess import postprocess 
														 | 
														
														 | 
														
															 from models.wirenet.postprocess import postprocess 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+from datetime import datetime 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 def box_line(pred): 
														 | 
														
														 | 
														
															 def box_line(pred): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     ''' 
														 | 
														
														 | 
														
															     ''' 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -188,7 +191,8 @@ def show_line(imgs, pred, t_start): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5 
														 | 
														
														 | 
														
															     diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     line, line_score = postprocess(line, line_score, diag * 0.01, 0, False) 
														 | 
														
														 | 
														
															     line, line_score = postprocess(line, line_score, diag * 0.01, 0, False) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    print(f'lines num:{len(line)}') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     # count = np.sum(line_score > 0.9) 
														 | 
														
														 | 
														
															     # count = np.sum(line_score > 0.9) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     # print(f'draw line number:{count}') 
														 | 
														
														 | 
														
															     # print(f'draw line number:{count}') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -197,8 +201,8 @@ def show_line(imgs, pred, t_start): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     ax.imshow(np.array(im)) 
														 | 
														
														 | 
														
															     ax.imshow(np.array(im)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     for idx, (a, b) in enumerate(line): 
														 | 
														
														 | 
														
															     for idx, (a, b) in enumerate(line): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        if line_score[idx] < 0.9: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            continue 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # if line_score[idx] < 0.7: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        #     continue 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         ax.scatter(a[1], a[0], c='#871F78', s=2) 
														 | 
														
														 | 
														
															         ax.scatter(a[1], a[0], c='#871F78', s=2) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         ax.scatter(b[1], b[0], c='#871F78', s=2) 
														 | 
														
														 | 
														
															         ax.scatter(b[1], b[0], c='#871F78', s=2) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1) 
														 | 
														
														 | 
														
															         ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1) 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -281,7 +285,55 @@ def show_box(imgs, pred, t_start): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 # 将show_line与show_box合并,传入参数确定显示框还是线  都不显示,输出原图 
														 | 
														
														 | 
														
															 # 将show_line与show_box合并,传入参数确定显示框还是线  都不显示,输出原图 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-def show_box_or_line(imgs, pred, show_line=False, show_box=False): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# def show_box_or_line(imgs, pred, show_line=False, show_box=False): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     col = [ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#bfbfbf', '#969696', '#737373', '#525252', '#252525', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072' 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     ] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     # print(len(col)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     im = imgs.permute(1, 2, 0) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     boxes = pred[0]['boxes'].cpu().numpy() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     # 可视化预测结 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     fig, ax = plt.subplots(figsize=(10, 10)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     ax.imshow(np.array(im)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     if show_box: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         for idx, box in enumerate(boxes): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#             x0, y0, x1, y1 = box 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#             ax.add_patch( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                 plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     if show_line: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         for idx, (a, b) in enumerate(line): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#             ax.scatter(a[1], a[0], c='#871F78', s=2) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#             ax.scatter(b[1], b[0], c='#871F78', s=2) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#             ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     plt.show() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 将show_line与show_box合并,传入参数确定显示框还是线  一起画 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+def show_box_and_line(imgs, pred, show_line=False, show_box=False): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     col = [ 
														 | 
														
														 | 
														
															     col = [ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', 
														 | 
														
														 | 
														
															         '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5', 
														 | 
														
														 | 
														
															         '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5', 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -310,27 +362,43 @@ def show_box_or_line(imgs, pred, show_line=False, show_box=False): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512 
														 | 
														
														 | 
														
															     line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     # 可视化预测结 
														 | 
														
														 | 
														
															     # 可视化预测结 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    fig, ax = plt.subplots(figsize=(10, 10)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    ax.imshow(np.array(im)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    fig, axs = plt.subplots(1, 2, figsize=(10, 10)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     if show_box: 
														 | 
														
														 | 
														
															     if show_box: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        axs[0].imshow(np.array(im)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         for idx, box in enumerate(boxes): 
														 | 
														
														 | 
														
															         for idx, box in enumerate(boxes): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             x0, y0, x1, y1 = box 
														 | 
														
														 | 
														
															             x0, y0, x1, y1 = box 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            ax.add_patch( 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            axs[0].add_patch( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															                 plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1)) 
														 | 
														
														 | 
														
															                 plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        axs[0].set_title('Boxes') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     if show_line: 
														 | 
														
														 | 
														
															     if show_line: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        axs[1].imshow(np.array(im)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         for idx, (a, b) in enumerate(line): 
														 | 
														
														 | 
														
															         for idx, (a, b) in enumerate(line): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            ax.scatter(a[1], a[0], c='#871F78', s=2) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            ax.scatter(b[1], b[0], c='#871F78', s=2) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            axs[1].scatter(a[1], a[0], c='#871F78', s=2) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            axs[1].scatter(b[1], b[0], c='#871F78', s=2) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        axs[1].set_title('Lines') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # 调整子图之间的距离,防止标题和标签重叠 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    plt.tight_layout() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     plt.show() 
														 | 
														
														 | 
														
															     plt.show() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-# 将show_line与show_box合并,传入参数确定显示框还是线  一起画 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-def show_box_and_line(imgs, pred, show_line=False, show_box=False): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    col = [ 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+def set_thresholds(threshold): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    if isinstance(threshold, list): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        if len(threshold) != 2: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            raise ValueError("Threshold list must contain exactly two elements.") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        a, b = threshold 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    elif isinstance(threshold, (int, float)): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        a = b = threshold 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    else: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        raise TypeError("Threshold must be either a list of two numbers or a single number.") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    return a, b 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+def color(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    return  [ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', 
														 | 
														
														 | 
														
															         '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5', 
														 | 
														
														 | 
														
															         '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f', 
														 | 
														
														 | 
														
															         '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f', 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -352,30 +420,115 @@ def show_box_and_line(imgs, pred, show_line=False, show_box=False): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026', 
														 | 
														
														 | 
														
															         '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072' 
														 | 
														
														 | 
														
															         '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072' 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     ] 
														 | 
														
														 | 
														
															     ] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    # print(len(col)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+def show_all(imgs, pred, threshold, save_path, show): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    col = color() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    box_th, line_th = set_thresholds(threshold) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     im = imgs.permute(1, 2, 0) 
														 | 
														
														 | 
														
															     im = imgs.permute(1, 2, 0) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     boxes = pred[0]['boxes'].cpu().numpy() 
														 | 
														
														 | 
														
															     boxes = pred[0]['boxes'].cpu().numpy() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    box_scores = pred[0]['scores'].cpu().numpy() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512 
														 | 
														
														 | 
														
															     line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    line_score = pred[-1]['wires']['score'].cpu().numpy()[0] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    line, line_score = postprocess(line, line_score, diag * 0.01, 0, False) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    fig, axs = plt.subplots(1, 3, figsize=(10, 10)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    axs[0].imshow(np.array(im)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    for idx, box in enumerate(boxes): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        if box_scores[idx] < box_th: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            continue 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        x0, y0, x1, y1 = box 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        axs[0].add_patch( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    axs[0].set_title('Boxes') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    axs[1].imshow(np.array(im)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    for idx, (a, b) in enumerate(line): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        if line_score[idx] < line_th: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            continue 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        axs[1].scatter(a[1], a[0], c='#871F78', s=2) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        axs[1].scatter(b[1], b[0], c='#871F78', s=2) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    axs[1].set_title('Lines') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    axs[2].imshow(np.array(im)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    lines = pred[0]['line'].cpu().numpy() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    line_scores = pred[0]['line_score'].cpu().numpy() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    idx = 0 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    tmp = np.array([[0.0, 0.0], [0.0, 0.0]]) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        x0, y0, x1, y1 = box 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # 框中无线的跳过 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        if np.array_equal(line, tmp): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            continue 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        a, b = line 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        if box_score >= 0.7 or line_score >= 0.9: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            axs[2].add_patch( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            axs[2].scatter(a[1], a[0], c='#871F78', s=10) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            axs[2].scatter(b[1], b[0], c='#871F78', s=10) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            axs[2].plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            idx = idx + 1 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    axs[2].set_title('Boxes and Lines') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    if save_path: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box_line.png') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        os.makedirs(os.path.dirname(save_path), exist_ok=True) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.savefig(save_path) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        print(f"Saved result image to {save_path}") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    if show: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # 调整子图之间的距离,防止标题和标签重叠 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.tight_layout() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.show() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+def show_box_or_line(imgs, pred, threshold, save_path = None, show_line=False, show_box=False): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    col = color() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    box_th, line_th = set_thresholds(threshold) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    im = imgs.permute(1, 2, 0) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    boxes = pred[0]['boxes'].cpu().numpy() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    box_scores = pred[0]['scores'].cpu().numpy() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    line_score = pred[-1]['wires']['score'].cpu().numpy()[0] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     # 可视化预测结 
														 | 
														
														 | 
														
															     # 可视化预测结 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    fig, axs = plt.subplots(1, 2, figsize=(10, 10)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    fig, ax = plt.subplots(figsize=(10, 10)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    ax.imshow(np.array(im)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     if show_box: 
														 | 
														
														 | 
														
															     if show_box: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        axs[0].imshow(np.array(im)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         for idx, box in enumerate(boxes): 
														 | 
														
														 | 
														
															         for idx, box in enumerate(boxes): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            if box_scores[idx] < box_th: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                continue 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             x0, y0, x1, y1 = box 
														 | 
														
														 | 
														
															             x0, y0, x1, y1 = box 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            axs[0].add_patch( 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            ax.add_patch( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															                 plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1)) 
														 | 
														
														 | 
														
															                 plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        axs[0].set_title('Boxes') 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        if save_path: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box.png') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            os.makedirs(os.path.dirname(save_path), exist_ok=True) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            plt.savefig(save_path) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            print(f"Saved result image to {save_path}") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     if show_line: 
														 | 
														
														 | 
														
															     if show_line: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        axs[1].imshow(np.array(im)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         for idx, (a, b) in enumerate(line): 
														 | 
														
														 | 
														
															         for idx, (a, b) in enumerate(line): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            axs[1].scatter(a[1], a[0], c='#871F78', s=2) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            axs[1].scatter(b[1], b[0], c='#871F78', s=2) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        axs[1].set_title('Lines') 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            if line_score[idx] < line_th: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                continue 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            ax.scatter(a[1], a[0], c='#871F78', s=2) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            ax.scatter(b[1], b[0], c='#871F78', s=2) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        if save_path: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'line.png') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            os.makedirs(os.path.dirname(save_path), exist_ok=True) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    # 调整子图之间的距离,防止标题和标签重叠 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    plt.tight_layout() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    plt.show() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            plt.savefig(save_path) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            print(f"Saved result image to {save_path}") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    plt.show() 
														 |