| 
					
				 | 
			
			
				@@ -33,6 +33,8 @@ import matplotlib as mpl 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from skimage import io 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import os.path as osp 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from torchvision.utils import draw_bounding_boxes 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from torchvision import transforms 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from models.wirenet.postprocess import postprocess 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 FEATURE_DIM = 8 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -583,32 +585,91 @@ def imshow(im): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     plt.colorbar(sm, fraction=0.046) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     plt.xlim([0, im.shape[0]]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     plt.ylim([im.shape[0], 0]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    plt.show() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # plt.show() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# def _plot_samples(img, i, result, prefix, epoch): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#     print(f"prefix:{prefix}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#     def draw_vecl(lines, sline, juncs, junts, fn): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#         directory = os.path.dirname(fn) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#         if not os.path.exists(directory): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#             os.makedirs(directory) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#         imshow(img.permute(1, 2, 0)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#         if len(lines) > 0 and not (lines[0] == 0).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#             for i, ((a, b), s) in enumerate(zip(lines, sline)): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#                 if i > 0 and (lines[i] == lines[0]).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#                     break 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#                 plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#         if not (juncs[0] == 0).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#             for i, j in enumerate(juncs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#                 if i > 0 and (i == juncs[0]).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#                     break 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#                 plt.scatter(j[1], j[0], c="red", s=64, zorder=100) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#         if junts is not None and len(junts) > 0 and not (junts[0] == 0).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#             for i, j in enumerate(junts): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#                 if i > 0 and (i == junts[0]).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#                     break 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#                 plt.scatter(j[1], j[0], c="blue", s=64, zorder=100) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#         plt.savefig(fn), plt.close() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#     rjuncs = result["juncs"][i].cpu().numpy() * 4 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#     rjunts = None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#     if "junts" in result: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#         rjunts = result["junts"][i].cpu().numpy() * 4 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#     vecl_result = result["lines"][i].cpu().numpy() * 4 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#     score = result["score"][i].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#     draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#     img1 = cv2.imread(f"{prefix}_vecl_b.jpg") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#     writer.add_image(f'output_epoch_{epoch}', img1, global_step=epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def _plot_samples(img, i, result, prefix, epoch, writer): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # print(f"prefix:{prefix}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-def _plot_samples(img, i, result, prefix, epoch): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    print(f"prefix:{prefix}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def draw_vecl(lines, sline, juncs, junts, fn): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if not os.path.exists(fn): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            os.makedirs(fn) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        imshow(img.permute(1, 2, 0)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 确保目录存在 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        directory = os.path.dirname(fn) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if not os.path.exists(directory): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            os.makedirs(directory) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 绘制图像 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.figure() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.imshow(img.permute(1, 2, 0).cpu().numpy()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.axis('off')  # 可选:关闭坐标轴 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if len(lines) > 0 and not (lines[0] == 0).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            for i, ((a, b), s) in enumerate(zip(lines, sline)): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                if i > 0 and (lines[i] == lines[0]).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for idx, ((a, b), s) in enumerate(zip(lines, sline)): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if idx > 0 and (lines[idx] == lines[0]).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     break 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if not (juncs[0] == 0).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            for i, j in enumerate(juncs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                if i > 0 and (i == juncs[0]).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for idx, j in enumerate(juncs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if idx > 0 and (j == juncs[0]).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     break 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                plt.scatter(j[1], j[0], c="red", s=64, zorder=100) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                plt.scatter(j[1], j[0], c="red", s=20, zorder=100) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if junts is not None and len(junts) > 0 and not (junts[0] == 0).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            for i, j in enumerate(junts): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                if i > 0 and (i == junts[0]).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for idx, j in enumerate(junts): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if idx > 0 and (j == junts[0]).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     break 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                plt.scatter(j[1], j[0], c="blue", s=64, zorder=100) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        plt.savefig(fn), plt.close() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                plt.scatter(j[1], j[0], c="blue", s=20, zorder=100) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # plt.show() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 将matplotlib图像转换为numpy数组 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.tight_layout() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        fig = plt.gcf() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        fig.canvas.draw() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            fig.canvas.get_width_height()[::-1] + (3,)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.close() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return image_from_plot 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # 获取结果数据并转换为numpy数组 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     rjuncs = result["juncs"][i].cpu().numpy() * 4 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     rjunts = None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     if "junts" in result: 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -617,10 +678,62 @@ def _plot_samples(img, i, result, prefix, epoch): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     vecl_result = result["lines"][i].cpu().numpy() * 4 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     score = result["score"][i].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    img1 = cv2.imread(f"{prefix}_vecl_b.jpg") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    writer.add_image(f'output_epoch_{epoch}', img1, global_step=epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # 调用绘图函数并获取图像 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    image_path = f"{prefix}_vecl_b.jpg" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    image_array = draw_vecl(vecl_result, score, rjuncs, rjunts, image_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # 将numpy数组转换为torch tensor,并写入TensorBoard 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    image_tensor = transforms.ToTensor()(image_array) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    writer.add_image(f'output_epoch', image_tensor, global_step=epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    writer.add_image(f'ori_epoch', img, global_step=epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def show_line(img, pred, prefix, epoch, write): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    fn = f"{prefix}_line.jpg" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    directory = os.path.dirname(fn) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if not os.path.exists(directory): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        os.makedirs(directory) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print(fn) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    H = pred 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    im = img.permute(1, 2, 0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    scores = H["score"][0].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for i in range(1, len(lines)): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if (lines[i] == lines[0]).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            lines = lines[:i] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            scores = scores[:i] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            break 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # postprocess lines to remove overlapped lines 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for i, t in enumerate([0.5]): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.gca().set_axis_off() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.margins(0, 0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for (a, b), s in zip(nlines, nscores): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if s < t: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            plt.scatter(a[1], a[0], **PLTOPTS) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            plt.scatter(b[1], b[0], **PLTOPTS) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.gca().xaxis.set_major_locator(plt.NullLocator()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.gca().yaxis.set_major_locator(plt.NullLocator()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.imshow(im) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.savefig(fn, bbox_inches="tight") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.show() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.close() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        img2 = cv2.imread(fn)  # 预测图 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # img1 = im.resize(img2.shape)  # 原图 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # writer.add_images(f"{epoch}", torch.tensor([img1, img2]), dataformats='NHWC') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        writer.add_image("output", img2, epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 if __name__ == '__main__': 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -698,27 +811,27 @@ if __name__ == '__main__': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         print(f"epoch:{epoch}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         model.train() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        for imgs, targets in data_loader_train: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            losses = model(move_to_device(imgs, device), move_to_device(targets, device)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            loss = _loss(losses) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            print(loss) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            optimizer.zero_grad() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            loss.backward() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            optimizer.step() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            writer_loss(writer, losses, epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            model.eval() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            with torch.no_grad(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                for batch_idx, (imgs, targets) in enumerate(data_loader_val): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    pred = model(move_to_device(imgs, device)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    # print(f"pred:{pred}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    if batch_idx == 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                        result = pred[1]['wires']  # pred[0].keys()   ['boxes', 'labels', 'scores'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                        print(imgs[0].shape)  # [3,512,512] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                        # imshow(imgs[0].permute(1, 2, 0))  # 改为(512, 512, 3) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                        _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # for imgs, targets in data_loader_train: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # losses = model(move_to_device(imgs, device), move_to_device(targets, device)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # loss = _loss(losses) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # print(loss) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # optimizer.zero_grad() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # loss.backward() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # optimizer.step() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # writer_loss(writer, losses, epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        model.eval() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        with torch.no_grad(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for batch_idx, (imgs, targets) in enumerate(data_loader_val): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                pred = model(move_to_device(imgs, device)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                # print(f"pred:{pred}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if batch_idx == 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    result = pred[1]['wires']  # pred[0].keys()   ['boxes', 'labels', 'scores'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    print(imgs[0].shape)  # [3,512,512] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    # imshow(imgs[0].permute(1, 2, 0))  # 改为(512, 512, 3) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch, writer) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    # show_line(imgs[0], result, f"{cfg['io']['logdir']}/{epoch}", epoch, writer) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # imgs, targets = next(iter(data_loader)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # 
			 |