| 
					
				 | 
			
			
				@@ -0,0 +1,371 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import time 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import skimage 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# from models.line_detect.postprocess import show_predict, show_box, show_box_or_line, show_box_and_line, \ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+#     show_line_optimized, show_line, show_all 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import os 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import torch 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from PIL import Image 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import matplotlib.pyplot as plt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import matplotlib as mpl 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import numpy as np 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from models.line_detect.line_net import linenet_resnet50_fpn 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from torchvision import transforms 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# from models.wirenet.postprocess import postprocess 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from models.wirenet.postprocess import postprocess 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from rtree import index 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from datetime import datetime 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def load_best_model(model, save_path, device): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if os.path.exists(save_path): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        checkpoint = torch.load(save_path, map_location=device) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        model.load_state_dict(checkpoint['model_state_dict']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # if optimizer is not None: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        #     optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        epoch = checkpoint['epoch'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        loss = checkpoint['loss'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        print(f"Loaded best model from {save_path} at epoch {epoch} with loss {loss:.4f}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        print(f"No saved model found at {save_path}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return model 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def box_line_(imgs, pred, length=False):  # 默认置信度 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    im = imgs.permute(1, 2, 0).cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    line_scores = pred[-1]['wires']['score'].cpu().numpy()[0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for idx, box_ in enumerate(pred[0:-1]): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        box = box_['boxes']  # 是一个tensor 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # line = pred[-1]['wires']['lines'][idx].cpu().numpy() / 128 * 512 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # score = pred[-1]['wires']['score'][idx] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # diag = (512 ** 2 + 512 ** 2) ** 0.5 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # line, score = postprocess(line, score, diag * 0.01, 0, False) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        line_ = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        score_ = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for i in box: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            score_max = 0.0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            tmp = [[0.0, 0.0], [0.0, 0.0]] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for j in range(len(line)): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        line[j][0][0] <= i[3] and line[j][1][0] <= i[3]): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    # # 计算线段长度 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    # length = np.linalg.norm(line[j][0] - line[j][1]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    # if length > score_max: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    #     tmp = line[j] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    #     score_max = score[j] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    if score[j] > score_max: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        tmp = line[j] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        score_max = score[j] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            line_.append(tmp) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            score_.append(score_max) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        processed_list = torch.tensor(line_) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        pred[idx]['line'] = processed_list 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        processed_s_list = torch.tensor(score_) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        pred[idx]['line_score'] = processed_s_list 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return pred 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def box_line_optimized(pred): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # 创建R-tree索引 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    idx = index.Index() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # 将所有线段添加到R-tree中 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    lines = pred[-1]['wires']['lines']  # 形状为[1, 2500, 2, 2] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    scores = pred[-1]['wires']['score'][0]  # 假设形状为[2500] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # 提取并处理所有线段 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for idx_line in range(lines.shape[1]):  # 遍历2500条线段 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        line_tensor = lines[0, idx_line].cpu().numpy() / 128 * 512  # 转换为numpy数组并调整比例 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        x_min = float(min(line_tensor[0][0], line_tensor[1][0])) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        y_min = float(min(line_tensor[0][1], line_tensor[1][1])) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        x_max = float(max(line_tensor[0][0], line_tensor[1][0])) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        y_max = float(max(line_tensor[0][1], line_tensor[1][1])) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        idx.insert(idx_line, (max(0, x_min - 256), max(0, y_min - 256), min(512, x_max + 256), min(512, y_max + 256))) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for idx_box, box_ in enumerate(pred[0:-1]): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        box = box_['boxes'].cpu().numpy()  # 确保将张量转换为numpy数组 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        line_ = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        score_ = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for i in box: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            score_max = 0.0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            tmp = [[0.0, 0.0], [0.0, 0.0]] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # 获取与当前box可能相交的所有线段 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            possible_matches = list(idx.intersection((i[0], i[1], i[2], i[3]))) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for j in possible_matches: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                line_j = lines[0, j].cpu().numpy() / 128 * 512 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and  # 注意这里交换了x和y 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        line_j[0][0] <= i[3] and line_j[1][0] <= i[3]): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    if scores[j] > score_max: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        tmp = line_j 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        score_max = scores[j] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            line_.append(tmp) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            score_.append(score_max) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        processed_list = torch.tensor(line_) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        pred[idx_box]['line'] = processed_list 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        processed_s_list = torch.tensor(score_) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        pred[idx_box]['line_score'] = processed_s_list 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return pred 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def set_thresholds(threshold): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if isinstance(threshold, list): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if len(threshold) != 2: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            raise ValueError("Threshold list must contain exactly two elements.") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        a, b = threshold 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    elif isinstance(threshold, (int, float)): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        a = b = threshold 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        raise TypeError("Threshold must be either a list of two numbers or a single number.") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return a, b 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def color(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return  [ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#bfbfbf', '#969696', '#737373', '#525252', '#252525', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def show_all(imgs, pred, threshold, save_path, show): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    col = color() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    box_th, line_th = set_thresholds(threshold) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    im = imgs.permute(1, 2, 0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    boxes = pred[0]['boxes'].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    box_scores = pred[0]['scores'].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    line_score = pred[-1]['wires']['score'].cpu().numpy()[0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    line, line_score = postprocess(line, line_score, diag * 0.01, 0, False) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    fig, axs = plt.subplots(1, 3, figsize=(10, 10)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    axs[0].imshow(np.array(im)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for idx, box in enumerate(boxes): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if box_scores[idx] < box_th: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        x0, y0, x1, y1 = box 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        axs[0].add_patch( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    axs[0].set_title('Boxes') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    axs[1].imshow(np.array(im)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for idx, (a, b) in enumerate(line): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if line_score[idx] < line_th: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        axs[1].scatter(a[1], a[0], c='#871F78', s=2) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        axs[1].scatter(b[1], b[0], c='#871F78', s=2) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        axs[1].plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    axs[1].set_title('Lines') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    axs[2].imshow(np.array(im)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    lines = pred[0]['line'].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    line_scores = pred[0]['line_score'].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    idx = 0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    tmp = np.array([[0.0, 0.0], [0.0, 0.0]]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        x0, y0, x1, y1 = box 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 框中无线的跳过 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if np.array_equal(line, tmp): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        a, b = line 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if box_score >= 0.7 or line_score >= 0.9: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            axs[2].add_patch( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            axs[2].scatter(a[1], a[0], c='#871F78', s=10) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            axs[2].scatter(b[1], b[0], c='#871F78', s=10) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            axs[2].plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            idx = idx + 1 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    axs[2].set_title('Boxes and Lines') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if save_path: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box_line.png') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        os.makedirs(os.path.dirname(save_path), exist_ok=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.savefig(save_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        print(f"Saved result image to {save_path}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if show: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 调整子图之间的距离,防止标题和标签重叠 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.tight_layout() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        plt.show() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def show_box_or_line(imgs, pred, threshold, save_path = None, show_line=False, show_box=False): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    col = color() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    box_th, line_th = set_thresholds(threshold) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    im = imgs.permute(1, 2, 0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    boxes = pred[0]['boxes'].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    box_scores = pred[0]['scores'].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    line = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    line_score = pred[-1]['wires']['score'].cpu().numpy()[0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # 可视化预测结 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    fig, ax = plt.subplots(figsize=(10, 10)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.imshow(np.array(im)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if show_box: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for idx, box in enumerate(boxes): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if box_scores[idx] < box_th: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            x0, y0, x1, y1 = box 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ax.add_patch( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if save_path: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'box.png') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            os.makedirs(os.path.dirname(save_path), exist_ok=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            plt.savefig(save_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            print(f"Saved result image to {save_path}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if show_line: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for idx, (a, b) in enumerate(line): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if line_score[idx] < line_th: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ax.scatter(a[1], a[0], c='#871F78', s=2) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ax.scatter(b[1], b[0], c='#871F78', s=2) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if save_path: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            save_path = os.path.join(datetime.now().strftime("%Y%m%d_%H%M%S"), 'line.png') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            os.makedirs(os.path.dirname(save_path), exist_ok=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            plt.savefig(save_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            print(f"Saved result image to {save_path}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    plt.show() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def show_predict(imgs, pred, threshold, t_start): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    col = color() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    box_th, line_th = set_thresholds(threshold) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    im = imgs.permute(1, 2, 0)  # 处理为 [512, 512, 3] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    boxes = pred[0]['boxes'].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    box_scores = pred[0]['scores'].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    lines = pred[0]['line'].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    line_scores = pred[0]['line_score'].cpu().numpy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # 可视化预测结 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    fig, ax = plt.subplots(figsize=(10, 10)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.imshow(np.array(im)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    idx = 0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    tmp = np.array([[0.0, 0.0], [0.0, 0.0]]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        x0, y0, x1, y1 = box 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 框中无线的跳过 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if np.array_equal(line, tmp): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        a, b = line 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if box_score >= box_th or line_score >= line_th: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ax.add_patch( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ax.scatter(a[1], a[0], c='#871F78', s=10) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ax.scatter(b[1], b[0], c='#871F78', s=10) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            idx = idx + 1 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    t_end = time.time() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print(f'predict used:{t_end - t_start}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    plt.show() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def predict(pt_path, model, img, type=0, threshold=0.5, save_path=None, show=False): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    model = load_best_model(model, pt_path, device) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    model.eval() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if isinstance(img, str): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        img = Image.open(img).convert("RGB") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    transform = transforms.ToTensor() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    img_tensor = transform(img)  # [3, 512, 512] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # 将图像调整为512x512大小 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    t_start = time.time() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    im = img_tensor.permute(1, 2, 0)  # [512, 512, 3] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    im_resized = skimage.transform.resize(im.cpu().numpy().astype(np.float32), (512, 512))  # (512, 512, 3) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    img_ = torch.tensor(im_resized).permute(2, 0, 1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    t_end = time.time() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print(f'switch img used:{t_end - t_start}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    with torch.no_grad(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        predictions = model([img_.to(device)]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # print(predictions) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    t_start = time.time() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    pred = box_line_(img_, predictions)  # 线框匹配 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    t_end = time.time() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print(f'Matched boxes and lines used: {t_end - t_start:.4f} seconds') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if type == 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        show_all(img_, pred, threshold, save_path=True, show=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    elif type == 1: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        show_box_or_line(img_, predictions, threshold, save_path=True, show_line=True)  # 参数确定画什么 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    elif type == 2: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        show_box_or_line(img_, predictions, threshold, save_path=True, show_box=True)  # 参数确定画什么 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    elif type == 3: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        show_predict(img_, pred, threshold, t_start) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+if __name__ == '__main__': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    t_start = time.time() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print(f'start to predict:{t_start}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    model = linenet_resnet50_fpn().to(device) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    pt_path = r'D:\python\PycharmProjects\20250214\weight\best.pth' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    img_path = r'C:\Users\m2337\Desktop\p\20250226142919.png' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # predict(pt_path, model, img_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    predict(pt_path, model, img_path, type=2, threshold=0.5, save_path=None, show=False) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    t_end = time.time() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print(f'predict used:{t_end - t_start}') 
			 |