| 
					
				 | 
			
			
				@@ -1,6 +1,7 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import os 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from typing import Optional, Any 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import cv2 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import numpy as np 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import torch 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from tensorboardX import SummaryWriter 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -27,9 +28,17 @@ from models.wirenet.wirepoint_dataset import WirePointDataset 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from tools import utils 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from torch.utils.tensorboard import SummaryWriter 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import matplotlib.pyplot as plt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import matplotlib as mpl 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from skimage import io 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import os.path as osp 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 FEATURE_DIM = 8 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+print(f"Using device: {device}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def non_maximum_suppression(a): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     ap = F.max_pool2d(a, 3, stride=1, padding=1) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -124,7 +133,7 @@ class WirepointRCNN(FasterRCNN): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if wirepoint_head is None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             keypoint_layers = tuple(512 for _ in range(8)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            print(f'keypoinyrcnnHeads inchannels:{out_channels},layers{keypoint_layers}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # print(f'keypoinyrcnnHeads inchannels:{out_channels},layers{keypoint_layers}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             wirepoint_head = WirepointHead(out_channels, keypoint_layers) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if wirepoint_predictor is None: 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -291,7 +300,7 @@ class WirepointPredictor(nn.Module): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         #     print(f'out:{out.shape}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # outputs=merge_features(outputs,100) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         batch, channel, row, col = inputs.shape 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        print(f'outputs:{inputs.shape}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # print(f'outputs:{inputs.shape}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # print(f'batch:{batch}, channel:{channel}, row:{row}, col:{col}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if targets is not None: 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -316,18 +325,18 @@ class WirepointPredictor(nn.Module): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             self.training = False 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             t = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "junc_coords": torch.zeros(1, 2), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "jtyp": torch.zeros(1, dtype=torch.uint8), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "junc_map": torch.zeros([1, 1, 128, 128]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "junc_offset": torch.zeros([1, 1, 2, 128, 128]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "junc_coords": torch.zeros(1, 2).to(device), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "jtyp": torch.zeros(1, dtype=torch.uint8).to(device), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "line_pos_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "line_neg_idx": torch.zeros(2, 2, dtype=torch.uint8).to(device), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "junc_map": torch.zeros([1, 1, 128, 128]).to(device), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             wires_targets = [t for b in range(inputs.size(0))] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             wires_meta = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "junc_map": torch.zeros([1, 1, 128, 128]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                "junc_offset": torch.zeros([1, 1, 2, 128, 128]), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "junc_map": torch.zeros([1, 1, 128, 128]).to(device), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "junc_offset": torch.zeros([1, 1, 2, 128, 128]).to(device), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         T = wires_meta.copy() 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -399,7 +408,6 @@ class WirepointPredictor(nn.Module): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         x, y = torch.cat(xs), torch.cat(ys) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         f = torch.cat(fs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         x = x.reshape(-1, self.n_pts1 * self.dim_loi) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        print(f"pstest{ps}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         x = torch.cat([x, f], 1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         x = x.to(dtype=torch.float32) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         x = self.fc2(x).flatten() 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -443,6 +451,9 @@ class WirepointPredictor(nn.Module): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             xy_ = xy[..., None, :] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             del x, y, index 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # print(f"xy_.is_cuda: {xy_.is_cuda}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # print(f"junc.is_cuda: {junc.is_cuda}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             # dist: [N_TYPE, K, N] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             dist = torch.sum((xy_ - junc) ** 2, -1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             cost, match = torch.min(dist, -1) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -555,6 +566,72 @@ def _loss(losses): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return total_loss 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+cmap = plt.get_cmap("jet") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+norm = mpl.colors.Normalize(vmin=0.4, vmax=1.0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+sm.set_array([]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def c(x): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return sm.to_rgba(x) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def imshow(im): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    plt.close() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    plt.tight_layout() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    plt.imshow(im) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    plt.colorbar(sm, fraction=0.046) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    plt.xlim([0, im.shape[0]]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    plt.ylim([im.shape[0], 0]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def _plot_samples(self, i, index, result, targets, prefix): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    fn = self.val_loader.dataset.filelist[index][:-10].replace("_a0", "") + ".png" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    img = io.imread(fn) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    imshow(img), plt.savefig(f"{prefix}_img.jpg"), plt.close() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def draw_vecl(lines, sline, juncs, junts, fn): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        imshow(img) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if len(lines) > 0 and not (lines[0] == 0).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for i, ((a, b), s) in enumerate(zip(lines, sline)): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if i > 0 and (lines[i] == lines[0]).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    break 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=4) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if not (juncs[0] == 0).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for i, j in enumerate(juncs): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if i > 0 and (i == juncs[0]).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    break 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                plt.scatter(j[1], j[0], c="red", s=64, zorder=100) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if junts is not None and len(junts) > 0 and not (junts[0] == 0).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for i, j in enumerate(junts): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if i > 0 and (i == junts[0]).all(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    break 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                plt.scatter(j[1], j[0], c="blue", s=64, zorder=100) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.savefig(fn), plt.close() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    junc = targets[i]["junc"].cpu().numpy() * 4 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    jtyp = targets[i]["jtyp"].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    juncs = junc[jtyp == 0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    junts = junc[jtyp == 1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    rjuncs = result["juncs"][i].cpu().numpy() * 4 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    rjunts = None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if "junts" in result: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        rjunts = result["junts"][i].cpu().numpy() * 4 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    lpre = targets[i]["lpre"].cpu().numpy() * 4 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    vecl_target = targets[i]["lpre_label"].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    vecl_result = result["lines"][i].cpu().numpy() * 4 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    score = result["score"][i].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    lpre = lpre[vecl_target == 1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    draw_vecl(lpre, np.ones(lpre.shape[0]), juncs, junts, f"{prefix}_vecl_a.jpg") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    draw_vecl(vecl_result, score, rjuncs, rjunts, f"{prefix}_vecl_b.jpg") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    img = cv2.imread(f"{prefix}_vecl_a.jpg") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    img1 = cv2.imread(f"{prefix}_vecl_b.jpg") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    self.writer.add_images(f"{self.epoch}", torch.tensor([img, img1]), dataformats='NHWC') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 if __name__ == '__main__': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     cfg = 'wirenet.yaml' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     cfg = read_yaml(cfg) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -562,15 +639,15 @@ if __name__ == '__main__': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     print(cfg['model']['n_dyn_negl']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     # net = WirepointPredictor() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    if torch.cuda.is_available(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        device_name = "cuda" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        torch.backends.cudnn.deterministic = True 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        torch.cuda.manual_seed(0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        print("Let's use", torch.cuda.device_count(), "GPU(s)!") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        print("CUDA is not available") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    device = torch.device(device_name) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # if torch.cuda.is_available(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    #     device_name = "cuda" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    #     torch.backends.cudnn.deterministic = True 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    #     torch.cuda.manual_seed(0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    #     print("Let's use", torch.cuda.device_count(), "GPU(s)!") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    #     print("CUDA is not available") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # device = torch.device(device_name) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     dataset_train = WirePointDataset(dataset_path=cfg['io']['datadir'], dataset_type='train') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     train_sampler = torch.utils.data.RandomSampler(dataset_train) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -607,17 +684,27 @@ if __name__ == '__main__': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             return data  # 对于非张量类型的数据不做任何改变 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def writer_loss(writer, losses): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # 记录每个损失项到TensorBoard 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        for key, value in losses.items(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            if isinstance(value, dict):  # 如果value本身也是一个字典(例如'loss_wirepoint') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                for subkey, subvalue in value['losses'][0].items(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    writer.add_scalar(f'{key}/{subkey}', subvalue.item(), epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                writer.add_scalar(key, value.item(), epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def writer_loss(writer, losses, epoch): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # ?????? 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for key, value in losses.items(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if key == 'loss_wirepoint': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    # ?? wirepoint ?????? 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    for subdict in losses['loss_wirepoint']['losses']: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        for subkey, subvalue in subdict.items(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            # ?? .item() ????? 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            writer.add_scalar(f'loss_wirepoint/{subkey}', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                              subvalue.item() if hasattr(subvalue, 'item') else subvalue, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                              epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                elif isinstance(value, torch.Tensor): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    # ???????? 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    writer.add_scalar(key, value.item(), epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        except Exception as e: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            print(f"TensorBoard logging error: {e}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     for epoch in range(cfg['optim']['max_epoch']): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        print(f"epoch:{epoch}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         model.train() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         for imgs, targets in data_loader_train: 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -627,14 +714,18 @@ if __name__ == '__main__': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             optimizer.zero_grad() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             loss.backward() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             optimizer.step() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            writer_loss(writer, losses) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            writer_loss(writer, losses, epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             model.eval() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             with torch.no_grad(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                for imgs, targets in data_loader_val: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    print(111) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                for batch_idx, (imgs, targets) in enumerate(data_loader_val): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     pred = model(move_to_device(imgs, device)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    print(f"pred:{pred}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    print(f"perd:{pred}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                # if batch_idx == 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                #     viz = osp.join(cfg['io']['logdir'], "viz", f"{epoch}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                #     H = pred["wires"] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                #     _plot_samples(0, 0, H, targets["wires"], f"{viz}/{epoch}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # imgs, targets = next(iter(data_loader)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -645,83 +736,3 @@ if __name__ == '__main__': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # result, losses = model(imgs, targets) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # print(f'result:{result}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # print(f'pred:{losses}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-''' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-########### predict############# 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    img_path=r"I:\wirenet_dateset\images\train\00030078_2.png" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    transforms = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT.transforms() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    img = read_image(img_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    img = transforms(img) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    img = torch.ones((2, 3, 512, 512)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # print(f'img shape:{img.shape}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    model.eval() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    onnx_file_path = "./wirenet.onnx" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # 导出模型为ONNX格式 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # torch.onnx.export(model, img, onnx_file_path, verbose=True, input_names=['input'], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    #                   output_names=['output']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # torch.save(model,'./wirenet.pt') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # 5. 指定输出的 ONNX 文件名 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # onnx_file_path = "./wirepoint_rcnn.onnx" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # 准备一个示例输入:Mask R-CNN 需要一个图像列表作为输入,每个图像张量的形状应为 [C, H, W] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    img = [torch.ones((3, 800, 800))]  # 示例输入图像大小为 800x800,3个通道 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # 指定输出的 ONNX 文件名 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # onnx_file_path = "./mask_rcnn.onnx" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # model_scripted = torch.jit.script(model) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # torch.onnx.export(model_scripted, input, "model.onnx", verbose=True, input_names=["input"], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    #                   output_names=["output"]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # print(f"Model has been converted to ONNX and saved to {onnx_file_path}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    pred=model(img) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    print(f'pred:{pred}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-################################################## end predict 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-########## traing ################################### 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # imgs, targets = next(iter(data_loader)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # model.train() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # pred = model(imgs, targets) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # class WrapperModule(torch.nn.Module): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    #     def __init__(self, model): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    #         super(WrapperModule, self).__init__() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    #         self.model = model 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    #     def forward(self,img, targets): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    #         # 在这里处理复杂的输入结构,将其转换为适合追踪的形式 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    #         return self.model(img,targets) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # torch.save(model.state_dict(),'./wire.pt') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # 包装原始模型 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # wrapped_model = WrapperModule(model) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # # model_scripted = torch.jit.trace(wrapped_model,img) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # writer = SummaryWriter('./') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # writer.add_graph(wrapped_model, (imgs,targets)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # writer.close() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # print(f'pred:{pred}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-########## end traing ################################### 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # for imgs,targets in data_loader: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    #     print(f'imgs:{imgs}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    #     print(f'targets:{targets}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-''' 
			 |