| 
														
															@@ -33,6 +33,8 @@ import matplotlib as mpl 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 from skimage import io 
														 | 
														
														 | 
														
															 from skimage import io 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 import os.path as osp 
														 | 
														
														 | 
														
															 import os.path as osp 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 from torchvision.utils import draw_bounding_boxes 
														 | 
														
														 | 
														
															 from torchvision.utils import draw_bounding_boxes 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+from torchvision import transforms 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+from models.wirenet.postprocess import postprocess 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 FEATURE_DIM = 8 
														 | 
														
														 | 
														
															 FEATURE_DIM = 8 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -583,32 +585,91 @@ def imshow(im): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     plt.colorbar(sm, fraction=0.046) 
														 | 
														
														 | 
														
															     plt.colorbar(sm, fraction=0.046) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     plt.xlim([0, im.shape[0]]) 
														 | 
														
														 | 
														
															     plt.xlim([0, im.shape[0]]) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     plt.ylim([im.shape[0], 0]) 
														 | 
														
														 | 
														
															     plt.ylim([im.shape[0], 0]) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    plt.show() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # plt.show() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# def _plot_samples(img, i, result, prefix, epoch): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     print(f"prefix:{prefix}") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     def draw_vecl(lines, sline, juncs, junts, fn): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         directory = os.path.dirname(fn) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         if not os.path.exists(directory): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#             os.makedirs(directory) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         imshow(img.permute(1, 2, 0)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         if len(lines) > 0 and not (lines[0] == 0).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#             for i, ((a, b), s) in enumerate(zip(lines, sline)): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                 if i > 0 and (lines[i] == lines[0]).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                     break 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                 plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         if not (juncs[0] == 0).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#             for i, j in enumerate(juncs): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                 if i > 0 and (i == juncs[0]).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                     break 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                 plt.scatter(j[1], j[0], c="red", s=64, zorder=100) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         if junts is not None and len(junts) > 0 and not (junts[0] == 0).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#             for i, j in enumerate(junts): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                 if i > 0 and (i == junts[0]).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                     break 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                 plt.scatter(j[1], j[0], c="blue", s=64, zorder=100) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         plt.savefig(fn), plt.close() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     rjuncs = result["juncs"][i].cpu().numpy() * 4 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     rjunts = None 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     if "junts" in result: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         rjunts = result["junts"][i].cpu().numpy() * 4 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     vecl_result = result["lines"][i].cpu().numpy() * 4 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     score = result["score"][i].cpu().numpy() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     img1 = cv2.imread(f"{prefix}_vecl_b.jpg") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     writer.add_image(f'output_epoch_{epoch}', img1, global_step=epoch) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+def _plot_samples(img, i, result, prefix, epoch, writer): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # print(f"prefix:{prefix}") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-def _plot_samples(img, i, result, prefix, epoch): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    print(f"prefix:{prefix}") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     def draw_vecl(lines, sline, juncs, junts, fn): 
														 | 
														
														 | 
														
															     def draw_vecl(lines, sline, juncs, junts, fn): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        if not os.path.exists(fn): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            os.makedirs(fn) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        imshow(img.permute(1, 2, 0)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # 确保目录存在 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        directory = os.path.dirname(fn) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        if not os.path.exists(directory): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            os.makedirs(directory) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # 绘制图像 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.figure() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.imshow(img.permute(1, 2, 0).cpu().numpy()) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.axis('off')  # 可选:关闭坐标轴 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         if len(lines) > 0 and not (lines[0] == 0).all(): 
														 | 
														
														 | 
														
															         if len(lines) > 0 and not (lines[0] == 0).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            for i, ((a, b), s) in enumerate(zip(lines, sline)): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                if i > 0 and (lines[i] == lines[0]).all(): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            for idx, ((a, b), s) in enumerate(zip(lines, sline)): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                if idx > 0 and (lines[idx] == lines[0]).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															                     break 
														 | 
														
														 | 
														
															                     break 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=1) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         if not (juncs[0] == 0).all(): 
														 | 
														
														 | 
														
															         if not (juncs[0] == 0).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            for i, j in enumerate(juncs): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                if i > 0 and (i == juncs[0]).all(): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            for idx, j in enumerate(juncs): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                if idx > 0 and (j == juncs[0]).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															                     break 
														 | 
														
														 | 
														
															                     break 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                plt.scatter(j[1], j[0], c="red", s=64, zorder=100) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                plt.scatter(j[1], j[0], c="red", s=20, zorder=100) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         if junts is not None and len(junts) > 0 and not (junts[0] == 0).all(): 
														 | 
														
														 | 
														
															         if junts is not None and len(junts) > 0 and not (junts[0] == 0).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            for i, j in enumerate(junts): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                if i > 0 and (i == junts[0]).all(): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            for idx, j in enumerate(junts): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                if idx > 0 and (j == junts[0]).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															                     break 
														 | 
														
														 | 
														
															                     break 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                plt.scatter(j[1], j[0], c="blue", s=64, zorder=100) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        plt.savefig(fn), plt.close() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                plt.scatter(j[1], j[0], c="blue", s=20, zorder=100) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # plt.show() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # 将matplotlib图像转换为numpy数组 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.tight_layout() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        fig = plt.gcf() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        fig.canvas.draw() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            fig.canvas.get_width_height()[::-1] + (3,)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.close() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        return image_from_plot 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # 获取结果数据并转换为numpy数组 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     rjuncs = result["juncs"][i].cpu().numpy() * 4 
														 | 
														
														 | 
														
															     rjuncs = result["juncs"][i].cpu().numpy() * 4 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     rjunts = None 
														 | 
														
														 | 
														
															     rjunts = None 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     if "junts" in result: 
														 | 
														
														 | 
														
															     if "junts" in result: 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -617,10 +678,62 @@ def _plot_samples(img, i, result, prefix, epoch): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     vecl_result = result["lines"][i].cpu().numpy() * 4 
														 | 
														
														 | 
														
															     vecl_result = result["lines"][i].cpu().numpy() * 4 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     score = result["score"][i].cpu().numpy() 
														 | 
														
														 | 
														
															     score = result["score"][i].cpu().numpy() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    img1 = cv2.imread(f"{prefix}_vecl_b.jpg") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    writer.add_image(f'output_epoch_{epoch}', img1, global_step=epoch) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # 调用绘图函数并获取图像 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    image_path = f"{prefix}_vecl_b.jpg" 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    image_array = draw_vecl(vecl_result, score, rjuncs, rjunts, image_path) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # 将numpy数组转换为torch tensor,并写入TensorBoard 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    image_tensor = transforms.ToTensor()(image_array) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    writer.add_image(f'output_epoch', image_tensor, global_step=epoch) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    writer.add_image(f'ori_epoch', img, global_step=epoch) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+def show_line(img, pred, prefix, epoch, write): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    fn = f"{prefix}_line.jpg" 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    directory = os.path.dirname(fn) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    if not os.path.exists(directory): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        os.makedirs(directory) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    print(fn) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5} 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    H = pred 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    im = img.permute(1, 2, 0) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    scores = H["score"][0].cpu().numpy() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    for i in range(1, len(lines)): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        if (lines[i] == lines[0]).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            lines = lines[:i] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            scores = scores[:i] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            break 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # postprocess lines to remove overlapped lines 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    for i, t in enumerate([0.5]): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.gca().set_axis_off() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.margins(0, 0) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        for (a, b), s in zip(nlines, nscores): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            if s < t: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                continue 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            plt.scatter(a[1], a[0], **PLTOPTS) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            plt.scatter(b[1], b[0], **PLTOPTS) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.gca().xaxis.set_major_locator(plt.NullLocator()) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.gca().yaxis.set_major_locator(plt.NullLocator()) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.imshow(im) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.savefig(fn, bbox_inches="tight") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.show() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.close() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        img2 = cv2.imread(fn)  # 预测图 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # img1 = im.resize(img2.shape)  # 原图 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # writer.add_images(f"{epoch}", torch.tensor([img1, img2]), dataformats='NHWC') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        writer.add_image("output", img2, epoch) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 if __name__ == '__main__': 
														 | 
														
														 | 
														
															 if __name__ == '__main__': 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -698,27 +811,27 @@ if __name__ == '__main__': 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         print(f"epoch:{epoch}") 
														 | 
														
														 | 
														
															         print(f"epoch:{epoch}") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         model.train() 
														 | 
														
														 | 
														
															         model.train() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        for imgs, targets in data_loader_train: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            losses = model(move_to_device(imgs, device), move_to_device(targets, device)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            loss = _loss(losses) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            print(loss) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            optimizer.zero_grad() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            loss.backward() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            optimizer.step() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            writer_loss(writer, losses, epoch) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            model.eval() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            with torch.no_grad(): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                for batch_idx, (imgs, targets) in enumerate(data_loader_val): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                    pred = model(move_to_device(imgs, device)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                    # print(f"pred:{pred}") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                    if batch_idx == 0: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                        result = pred[1]['wires']  # pred[0].keys()   ['boxes', 'labels', 'scores'] 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                        print(imgs[0].shape)  # [3,512,512] 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                        # imshow(imgs[0].permute(1, 2, 0))  # 改为(512, 512, 3) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                        _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # for imgs, targets in data_loader_train: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # losses = model(move_to_device(imgs, device), move_to_device(targets, device)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # loss = _loss(losses) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # print(loss) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # optimizer.zero_grad() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # loss.backward() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # optimizer.step() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # writer_loss(writer, losses, epoch) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        model.eval() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        with torch.no_grad(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            for batch_idx, (imgs, targets) in enumerate(data_loader_val): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                pred = model(move_to_device(imgs, device)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                # print(f"pred:{pred}") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                if batch_idx == 0: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                    result = pred[1]['wires']  # pred[0].keys()   ['boxes', 'labels', 'scores'] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                    print(imgs[0].shape)  # [3,512,512] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                    # imshow(imgs[0].permute(1, 2, 0))  # 改为(512, 512, 3) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                    _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch, writer) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                    # show_line(imgs[0], result, f"{cfg['io']['logdir']}/{epoch}", epoch, writer) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 # imgs, targets = next(iter(data_loader)) 
														 | 
														
														 | 
														
															 # imgs, targets = next(iter(data_loader)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 # 
														 | 
														
														 | 
														
															 # 
														 |