| 
														
															@@ -10,6 +10,7 @@ import torch.nn.functional as F 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 # from torchinfo import summary 
														 | 
														
														 | 
														
															 # from torchinfo import summary 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 from torchvision.io import read_image 
														 | 
														
														 | 
														
															 from torchvision.io import read_image 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 from torchvision.models import resnet50, ResNet50_Weights 
														 | 
														
														 | 
														
															 from torchvision.models import resnet50, ResNet50_Weights 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+from torchvision.models import resnet18, ResNet18_Weights 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 from torchvision.models.detection import FasterRCNN, MaskRCNN_ResNet50_FPN_V2_Weights 
														 | 
														
														 | 
														
															 from torchvision.models.detection import FasterRCNN, MaskRCNN_ResNet50_FPN_V2_Weights 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 from torchvision.models.detection._utils import overwrite_eps 
														 | 
														
														 | 
														
															 from torchvision.models.detection._utils import overwrite_eps 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers 
														 | 
														
														 | 
														
															 from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -522,26 +523,55 @@ class WirepointPredictor(nn.Module): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             jcs = [xy[i, score[i] > 0.03] for i in range(n_type)] 
														 | 
														
														 | 
														
															             jcs = [xy[i, score[i] > 0.03] for i in range(n_type)] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             return line, label.float(), feat, jcs 
														 | 
														
														 | 
														
															             return line, label.float(), feat, jcs 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# def wirepointrcnn_resnet50_fpn( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         *, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         progress: bool = True, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         num_classes: Optional[int] = None, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         num_keypoints: Optional[int] = None, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         trainable_backbone_layers: Optional[int] = None, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         **kwargs: Any, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# ) -> WirepointRCNN: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     weights_backbone = ResNet50_Weights.verify(weights_backbone) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     is_trained = weights is not None or weights_backbone is not None 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     model = WirepointRCNN(backbone, num_classes=5, **kwargs) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     if weights is not None: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         model.load_state_dict(weights.get_state_dict(progress=progress)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#             overwrite_eps(model, 0.0) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     return model 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-def wirepointrcnn_resnet50_fpn( 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+def wirepointrcnn_resnet18_fpn( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         *, 
														 | 
														
														 | 
														
															         *, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None, 
														 | 
														
														 | 
														
															         weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         progress: bool = True, 
														 | 
														
														 | 
														
															         progress: bool = True, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         num_classes: Optional[int] = None, 
														 | 
														
														 | 
														
															         num_classes: Optional[int] = None, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         num_keypoints: Optional[int] = None, 
														 | 
														
														 | 
														
															         num_keypoints: Optional[int] = None, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1, 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        weights_backbone: Optional[ResNet18_Weights] = ResNet18_Weights.IMAGENET1K_V1, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         trainable_backbone_layers: Optional[int] = None, 
														 | 
														
														 | 
														
															         trainable_backbone_layers: Optional[int] = None, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         **kwargs: Any, 
														 | 
														
														 | 
														
															         **kwargs: Any, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 ) -> WirepointRCNN: 
														 | 
														
														 | 
														
															 ) -> WirepointRCNN: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights) 
														 | 
														
														 | 
														
															     weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    weights_backbone = ResNet50_Weights.verify(weights_backbone) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    weights_backbone = ResNet18_Weights.verify(weights_backbone) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     is_trained = weights is not None or weights_backbone is not None 
														 | 
														
														 | 
														
															     is_trained = weights is not None or weights_backbone is not None 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) 
														 | 
														
														 | 
														
															     trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d 
														 | 
														
														 | 
														
															     norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    backbone = resnet18(weights=weights_backbone, progress=progress, norm_layer=norm_layer) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) 
														 | 
														
														 | 
														
															     backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     model = WirepointRCNN(backbone, num_classes=5, **kwargs) 
														 | 
														
														 | 
														
															     model = WirepointRCNN(backbone, num_classes=5, **kwargs) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -577,27 +607,128 @@ sm.set_array([]) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 def c(x): 
														 | 
														
														 | 
														
															 def c(x): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     return sm.to_rgba(x) 
														 | 
														
														 | 
														
															     return sm.to_rgba(x) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+def imshow(im): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    plt.close() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    plt.tight_layout() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    plt.imshow(im) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    plt.colorbar(sm, fraction=0.046) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    plt.xlim([0, im.shape[0]]) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    plt.ylim([im.shape[0], 0]) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # plt.show() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# def _plot_samples(img, i, result, prefix, epoch): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     print(f"prefix:{prefix}") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     def draw_vecl(lines, sline, juncs, junts, fn): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         directory = os.path.dirname(fn) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         if not os.path.exists(directory): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#             os.makedirs(directory) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         imshow(img.permute(1, 2, 0)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         if len(lines) > 0 and not (lines[0] == 0).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#             for i, ((a, b), s) in enumerate(zip(lines, sline)): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                 if i > 0 and (lines[i] == lines[0]).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                     break 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                 plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         if not (juncs[0] == 0).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#             for i, j in enumerate(juncs): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                 if i > 0 and (i == juncs[0]).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                     break 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                 plt.scatter(j[1], j[0], c="red", s=64, zorder=100) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         if junts is not None and len(junts) > 0 and not (junts[0] == 0).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#             for i, j in enumerate(junts): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                 if i > 0 and (i == junts[0]).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                     break 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#                 plt.scatter(j[1], j[0], c="blue", s=64, zorder=100) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         plt.savefig(fn), plt.close() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     rjuncs = result["juncs"][i].cpu().numpy() * 4 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     rjunts = None 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     if "junts" in result: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#         rjunts = result["junts"][i].cpu().numpy() * 4 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     vecl_result = result["lines"][i].cpu().numpy() * 4 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     score = result["score"][i].cpu().numpy() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 # 
														 | 
														
														 | 
														
															 # 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-# def imshow(im): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-#     plt.close() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-#     plt.tight_layout() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-#     plt.imshow(im) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-#     plt.colorbar(sm, fraction=0.046) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-#     plt.xlim([0, im.shape[0]]) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-#     plt.ylim([im.shape[0], 0]) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-#     # plt.show() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     img1 = cv2.imread(f"{prefix}_vecl_b.jpg") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+#     writer.add_image(f'output_epoch_{epoch}', img1, global_step=epoch) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+def _plot_samples(img, i, result, prefix, epoch, writer): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # print(f"prefix:{prefix}") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    def draw_vecl(lines, sline, juncs, junts, fn): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # 确保目录存在 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        directory = os.path.dirname(fn) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        if not os.path.exists(directory): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            os.makedirs(directory) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # 绘制图像 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.figure() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.imshow(img.permute(1, 2, 0).cpu().numpy()) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.axis('off')  # 可选:关闭坐标轴 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        if len(lines) > 0 and not (lines[0] == 0).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            for idx, ((a, b), s) in enumerate(zip(lines, sline)): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                if idx > 0 and (lines[idx] == lines[0]).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                    break 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=1) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        if not (juncs[0] == 0).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            for idx, j in enumerate(juncs): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                if idx > 0 and (j == juncs[0]).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                    break 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                plt.scatter(j[1], j[0], c="red", s=20, zorder=100) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        if junts is not None and len(junts) > 0 and not (junts[0] == 0).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            for idx, j in enumerate(junts): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                if idx > 0 and (j == junts[0]).all(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                    break 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                plt.scatter(j[1], j[0], c="blue", s=20, zorder=100) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # plt.show() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # 将matplotlib图像转换为numpy数组 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.tight_layout() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        fig = plt.gcf() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        fig.canvas.draw() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            fig.canvas.get_width_height()[::-1] + (3,)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.close() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        return image_from_plot 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-def show_line(img, pred,  epoch, write): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    im = img.permute(1, 2, 0) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    writer.add_image("ori", im, epoch, dataformats="HWC") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # 获取结果数据并转换为numpy数组 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    rjuncs = result["juncs"][i].cpu().numpy() * 4 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    rjunts = None 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    if "junts" in result: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        rjunts = result["junts"][i].cpu().numpy() * 4 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"], 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                                      colors="yellow", width=1) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    writer.add_image("boxes", boxed_image.permute(1, 2, 0), epoch, dataformats="HWC") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    vecl_result = result["lines"][i].cpu().numpy() * 4 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    score = result["score"][i].cpu().numpy() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # 调用绘图函数并获取图像 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    image_path = f"{prefix}_vecl_b.jpg" 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    image_array = draw_vecl(vecl_result, score, rjuncs, rjunts, image_path) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # 将numpy数组转换为torch tensor,并写入TensorBoard 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    image_tensor = transforms.ToTensor()(image_array) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    writer.add_image(f'output_epoch', image_tensor, global_step=epoch) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    writer.add_image(f'ori_epoch', img, global_step=epoch) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+def show_line(img, pred, prefix, epoch, write): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    fn = f"{prefix}_line.jpg" 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    directory = os.path.dirname(fn) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    if not os.path.exists(directory): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        os.makedirs(directory) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    print(fn) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5} 
														 | 
														
														 | 
														
															     PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5} 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    H = pred[1]['wires'] 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    H = pred 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    im = img.permute(1, 2, 0) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2] 
														 | 
														
														 | 
														
															     lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     scores = H["score"][0].cpu().numpy() 
														 | 
														
														 | 
														
															     scores = H["score"][0].cpu().numpy() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     for i in range(1, len(lines)): 
														 | 
														
														 | 
														
															     for i in range(1, len(lines)): 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -623,14 +754,14 @@ def show_line(img, pred,  epoch, write): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         plt.gca().xaxis.set_major_locator(plt.NullLocator()) 
														 | 
														
														 | 
														
															         plt.gca().xaxis.set_major_locator(plt.NullLocator()) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         plt.gca().yaxis.set_major_locator(plt.NullLocator()) 
														 | 
														
														 | 
														
															         plt.gca().yaxis.set_major_locator(plt.NullLocator()) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         plt.imshow(im) 
														 | 
														
														 | 
														
															         plt.imshow(im) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        plt.tight_layout() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        fig = plt.gcf() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        fig.canvas.draw() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape( 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            fig.canvas.get_width_height()[::-1] + (3,)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.savefig(fn, bbox_inches="tight") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        plt.show() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         plt.close() 
														 | 
														
														 | 
														
															         plt.close() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        img2 = transforms.ToTensor()(image_from_plot) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        img2 = cv2.imread(fn)  # 预测图 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # img1 = im.resize(img2.shape)  # 原图 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # writer.add_images(f"{epoch}", torch.tensor([img1, img2]), dataformats='NHWC') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         writer.add_image("output", img2, epoch) 
														 | 
														
														 | 
														
															         writer.add_image("output", img2, epoch) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -669,7 +800,11 @@ if __name__ == '__main__': 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn 
														 | 
														
														 | 
														
															         dataset_val, batch_sampler=val_batch_sampler, num_workers=0, collate_fn=val_collate_fn 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     ) 
														 | 
														
														 | 
														
															     ) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    model = wirepointrcnn_resnet50_fpn().to(device) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    model = wirepointrcnn_resnet18_fpn().to(device) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # print(model) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # model1 = wirepointrcnn_resnet50_fpn().to(device) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # print(model1) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr']) 
														 | 
														
														 | 
														
															     optimizer = torch.optim.Adam(model.parameters(), lr=cfg['optim']['lr']) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     writer = SummaryWriter(cfg['io']['logdir']) 
														 | 
														
														 | 
														
															     writer = SummaryWriter(cfg['io']['logdir']) 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -709,23 +844,27 @@ if __name__ == '__main__': 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         print(f"epoch:{epoch}") 
														 | 
														
														 | 
														
															         print(f"epoch:{epoch}") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         model.train() 
														 | 
														
														 | 
														
															         model.train() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        # for imgs, targets in data_loader_train: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        # losses = model(move_to_device(imgs, device), move_to_device(targets, device)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        # loss = _loss(losses) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        # print(loss) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        # optimizer.zero_grad() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        # loss.backward() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        # optimizer.step() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        # writer_loss(writer, losses, epoch) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        model.eval() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        with torch.no_grad(): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            for batch_idx, (imgs, targets) in enumerate(data_loader_val): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                pred = model(move_to_device(imgs, device))  # # pred[0].keys()   ['boxes', 'labels', 'scores'] 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                # print(f"pred:{pred}") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                if batch_idx == 0: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                    show_line(imgs[0], pred,  epoch, writer) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        for imgs, targets in data_loader_train: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            losses = model(move_to_device(imgs, device), move_to_device(targets, device)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            loss = _loss(losses) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            print(f"loss:{loss}") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            optimizer.zero_grad() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            loss.backward() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            optimizer.step() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            writer_loss(writer, losses, epoch) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # model.eval() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # with torch.no_grad(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        #     for batch_idx, (imgs, targets) in enumerate(data_loader_val): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        #         pred = model(move_to_device(imgs, device)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        #         print(f"pred:{pred}") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        # 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        #         if batch_idx == 0: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        #             result = pred[1]['wires']  # pred[0].keys()   ['boxes', 'labels', 'scores'] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        #             print(imgs[0].shape)  # [3,512,512] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        #             # imshow(imgs[0].permute(1, 2, 0))  # 改为(512, 512, 3) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        #             _plot_samples(imgs[0], 0, result, f"{cfg['io']['logdir']}/{epoch}/", epoch, writer) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        #             show_line(imgs[0], result, f"{cfg['io']['logdir']}/{epoch}", epoch, writer) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 # imgs, targets = next(iter(data_loader)) 
														 | 
														
														 | 
														
															 # imgs, targets = next(iter(data_loader)) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 # 
														 | 
														
														 | 
														
															 # 
														 |