@@ -10,6 +10,7 @@ import torch.nn.functional as F
 # from torchinfo import summary
 from torchvision.io import read_image
 from torchvision.models import resnet50, ResNet50_Weights
+from torchvision.models import resnet18, ResNet18_Weights
 from torchvision.models.detection import FasterRCNN, MaskRCNN_ResNet50_FPN_V2_Weights
 from torchvision.models.detection._utils import overwrite_eps
 from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
@@ -522,26 +523,55 @@ class WirepointPredictor(nn.Module):
             jcs = [xy[i, score[i] > 0.03] for i in range(n_type)]
             return line, label.float(), feat, jcs
 
+# def wirepointrcnn_resnet50_fpn(
+#         *,
+#         weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
+#         progress: bool = True,
+#         num_classes: Optional[int] = None,
+#         num_keypoints: Optional[int] = None,
+#         weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+#         trainable_backbone_layers: Optional[int] = None,
+#         **kwargs: Any,
+# ) -> WirepointRCNN:
+#     weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
+#     weights_backbone = ResNet50_Weights.verify(weights_backbone)
+#
+#     is_trained = weights is not None or weights_backbone is not None
+#     trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
+#
+#     norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
+#
+#     backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+#     backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
+#     model = WirepointRCNN(backbone, num_classes=5, **kwargs)
+#
+#     if weights is not None:
+#         model.load_state_dict(weights.get_state_dict(progress=progress))
+#         if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
+#             overwrite_eps(model, 0.0)
+#
+#     return model
 
-def wirepointrcnn_resnet50_fpn(
+
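+# Builder for a WirepointRCNN with a ResNet-18 + FPN backbone, following the same
+# pattern as the torchvision detection model factories (the ResNet-50 version is
+# kept commented out above for reference).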
+def wirepointrcnn_resnet18_fpn(
         *,
         weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
         progress: bool = True,
         num_classes: Optional[int] = None,
         num_keypoints: Optional[int] = None,
-        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
+        weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1,
         trainable_backbone_layers: Optional[int] = None,
         **kwargs: Any,
 ) -> WirepointRCNN:
     weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
-    weights_backbone = ResNet50_Weights.verify(weights_backbone)
+    weights_backbone = ResNet18_Weights.verify(weights_backbone)
 
     is_trained = weights is not None or weights_backbone is not None
     trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
 
     norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
 
-    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
+    backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
     backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
     model = WirepointRCNN(backbone, num_classes=5, **kwargs)
 
@@ -577,27 +607,128 @@ sm.set_array([])
 def c(x):
     return sm.to_rgba(x)
 
+
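+# Display helper: shows an image together with the module-level colorbar (`sm`);
+# the y-axis is inverted so the origin sits at the top-left, matching image coordinates.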
+def imshow(im):
+    plt.close()
+    plt.tight_layout()
+    plt.imshow(im)
+    plt.colorbar(sm, fraction=0.046)
+    plt.xlim([0, im.shape[0]])
+    plt.ylim([im.shape[0], 0])
+    # plt.show()
+
+
+# def _plot_samples(img, i, result, prefix, epoch):
+#     print(f"prefix:{prefix}")
+#     def draw_vecl(lines, sline, juncs, junts, fn):
+#         directory = os.path.dirname(fn)
+#         if not os.path.exists(directory):
+#             os.makedirs(directory)
+#         imshow(img.permute(1, 2, 0))
+#         if len(lines) > 0 and not (lines[0] == 0).all():
+#             for i, ((a, b), s) in enumerate(zip(lines, sline)):
+#                 if i > 0 and (lines[i] == lines[0]).all():
+#                     break
+#                 plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4)
+#         if not (juncs[0] == 0).all():
+#             for i, j in enumerate(juncs):
+#                 if i > 0 and (i == juncs[0]).all():
+#                     break
+#                 plt.scatter(j[1], j[0], c="red", s=64, zorder=100)
+#         if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
+#             for i, j in enumerate(junts):
+#                 if i > 0 and (i == junts[0]).all():
+#                     break
+#                 plt.scatter(j[1], j[0], c="blue", s=64, zorder=100)
+#         plt.savefig(fn), plt.close()
+#
+#     rjuncs = result["juncs"][i].cpu().numpy() * 4
+#     rjunts = None
+#     if "junts" in result:
+#         rjunts = result["junts"][i].cpu().numpy() * 4
+#
+#     vecl_result = result["lines"][i].cpu().numpy() * 4
+#     score = result["score"][i].cpu().numpy()
 #
-# def imshow(im):
-#     plt.close()
-#     plt.tight_layout()
-#     plt.imshow(im)
-#     plt.colorbar(sm, fraction=0.046)
-#     plt.xlim([0, im.shape[0]])
-#     plt.ylim([im.shape[0], 0])
-#     # plt.show()
+#     draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg")
+#
+#     img1 = cv2.imread(f"{prefix}_vecl_b.jpg")
+#     writer.add_image(f'output_epoch_{epoch}', img1, global_step=epoch)
+
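+# Draw the predicted lines and junctions for sample `i` of `result` and log both
+# the rendered figure and the raw input image to TensorBoard.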
+def _plot_samples(img, i, result, prefix, epoch, writer):
+    # print(f"prefix:{prefix}")
+
+    def draw_vecl(lines, sline, juncs, junts, fn):
+        # make sure the output directory exists
+        directory = os.path.dirname(fn)
+        if not os.path.exists(directory):
+            os.makedirs(directory)
+
+        # draw the image
+        plt.figure()
+        plt.imshow(img.permute(1, 2, 0).cpu().numpy())
+        plt.axis('off')  # optional: hide the axes
+
+        if len(lines) > 0 and not (lines[0] == 0).all():
+            for idx, ((a, b), s) in enumerate(zip(lines, sline)):
+                if idx > 0 and (lines[idx] == lines[0]).all():
+                    break
+                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=1)
+
+        if not (juncs[0] == 0).all():
+            for idx, j in enumerate(juncs):
+                if idx > 0 and (j == juncs[0]).all():
+                    break
+                plt.scatter(j[1], j[0], c="red", s=20, zorder=100)
+
+        if junts is not None and len(junts) > 0 and not (junts[0] == 0).all():
+            for idx, j in enumerate(junts):
+                if idx > 0 and (j == junts[0]).all():
+                    break
+                plt.scatter(j[1], j[0], c="blue", s=20, zorder=100)
+
+        # plt.show()
+
+        # convert the matplotlib figure to a numpy array
+        plt.tight_layout()
+        fig = plt.gcf()
+        fig.canvas.draw()
+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
+            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.close()
 
+        return image_from_plot
 
-def show_line(img, pred,  epoch, write):
-    im = img.permute(1, 2, 0)
-    writer.add_image("ori", im, epoch, dataformats="HWC")
+    # fetch the result tensors and convert them to numpy arrays
+    rjuncs = result["juncs"][i].cpu().numpy() * 4
+    rjunts = None
+    if "junts" in result:
+        rjunts = result["junts"][i].cpu().numpy() * 4
 
-    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
-                                      colors="yellow", width=1)
-    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC")
+    vecl_result = result["lines"][i].cpu().numpy() * 4
+    score = result["score"][i].cpu().numpy()
 
+    # call the drawing helper and get the rendered image
+    image_path = f"{prefix}_vecl_b.jpg"
+    image_array = draw_vecl(vecl_result, score, rjuncs, rjunts, image_path)
+
+    # convert the numpy array to a torch tensor and write it to TensorBoard
+    image_tensor = transforms.ToTensor()(image_array)
+    writer.add_image(f'output_epoch', image_tensor, global_step=epoch)
+    writer.add_image(f'ori_epoch', img, global_step=epoch)
+
+
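+# Plot the predicted wireframe for `img`, save it under `prefix` and log the saved
+# figure to TensorBoard (uses the global `writer`; the `write` argument is currently unused).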
+def show_line(img, pred, prefix, epoch, write):
+    fn = f"{prefix}_line.jpg"
+    directory = os.path.dirname(fn)
+    if not os.path.exists(directory):
+        os.makedirs(directory)
+    print(fn)
     PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
-    H = pred[1]['wires']
+    H = pred
+
+    im = img.permute(1, 2, 0)
+
     lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
     scores = H["score"][0].cpu().numpy()
     for i in range(1, len(lines)):
@@ -623,14 +754,14 @@ def show_line(img, pred,  epoch, write):
         plt.gca().xaxis.set_major_locator(plt.NullLocator())
         plt.gca().yaxis.set_major_locator(plt.NullLocator())
         plt.imshow(im)
-        plt.tight_layout()
-        fig = plt.gcf()
-        fig.canvas.draw()
-        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
-            fig.canvas.get_width_height()[::-1] + (3,))
+        plt.savefig(fn, bbox_inches="tight")
+        plt.show()
         plt.close()
-        img2 = transforms.ToTensor()(image_from_plot)
 
+        img2 = cv2.imread(fn)  # prediction image
+        # img1 = im.resize(img2.shape)  # original image
+
+        # writer.add_images(f"{epoch}", torch.tensor([img1, img2]), dataformats='NHWC')
         writer.add_image("output", img2, epoch)
 
 
@@ -669,7 +800,11 @@ if __name__ == '__main__':
         dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn
     )
 
-    model = wirepointrcnn_resnet50_fpn().to(device)
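+    # use the ResNet-18 FPN builder defined above (the ResNet-50 variant is kept below, commented out)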
+    model = wirepointrcnn_resnet18_fpn().to(device)
+    # print(model)
+
+    # model1 = wirepointrcnn_resnet50_fpn().to(device)
+    # print(model1)
 
     optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr'])
     writer = SummaryWriter(cfg['io']['logdir'])
@@ -709,23 +844,27 @@ if __name__ == '__main__':
         print(f"epoch:{epoch}")
         model.train()
 
-        # for imgs, targets in data_loader_train:
-        # losses = model(move_to_device(imgs, device), move_to_device(targets, device))
-        # loss = _loss(losses)
-        # print(loss)
-        # optimizer.zero_grad()
-        # loss.backward()
-        # optimizer.step()
-        # writer_loss(writer, losses, epoch)
-
-        model.eval()
-        with torch.no_grad():
-            for batch_idx, (imgs, targets) in enumerate(data_loader_val):
-                pred = model(move_to_device(imgs, device))  # # pred[0].keys()   ['boxes', 'labels', 'scores']
-                # print(f"pred:{pred}")
-
-                if batch_idx == 0:
-                    show_line(imgs[0], pred,  epoch, writer)
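+        # training step: forward pass to get the losses, reduce them with _loss, backprop, then log via writer_loss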
+        for imgs, targets in data_loader_train:
+            losses = model(move_to_device(imgs, device), move_to_device(targets, device))
+            loss = _loss(losses)
+            print(f"loss:{loss}")
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            writer_loss(writer, losses, epoch)
+
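+        # validation / visualisation pass, currently disabled; re-enable to run the model on the
+        # val loader and log predictions with _plot_samples / show_line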
+        # model.eval()
+        # with torch.no_grad():
+        #     for batch_idx, (imgs, targets) in enumerate(data_loader_val):
+        #         pred = model(move_to_device(imgs, device))
+        #         print(f"pred:{pred}")
+        #
+        #         if batch_idx == 0:
+        #             result = pred[1]['wires']  # pred[0].keys()   ['boxes', 'labels', 'scores']
+        #             print(imgs[0].shape)  # [3,512,512]
+        #             # imshow(imgs[0].permute(1, 2, 0))  # change to (512, 512, 3)
+        #             _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch, writer)
+        #             show_line(imgs[0], result, f"{cfg['io']['logdir']}/{epoch}", epoch, writer)
 
 # imgs, targets = next(iter(data_loader))
 #