import os
import torch
import numpy as np
from PIL import Image
import skimage.io
import skimage.color
import skimage.transform  # preprocess_image() calls skimage.transform.resize
from torchvision import transforms
import shutil
import matplotlib.pyplot as plt
from models.line_net.line_net import linenet_resnet50_fpn
from models.wirenet.postprocess import postprocess
from rtree import index
import time
import multiprocessing as mp

# Use the 'spawn' start method for multiprocessing; force=True overrides any
# start method that was already set in this process.
mp.set_start_method('spawn', force=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_model(model_path):
    """Build the LineNet model, load its checkpoint, and return it in eval mode.

    Args:
        model_path: path to a checkpoint dict containing a
            ``'model_state_dict'`` entry.

    Returns:
        The model, moved to ``device`` and switched to ``eval()`` mode.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
    """
    # Check the path first so a missing checkpoint fails before the network
    # is constructed.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No saved model found at {model_path}")
    model = linenet_resnet50_fpn().to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from {model_path}")
    model.eval()
    return model


def preprocess_image(image_path):
    """Load an image as RGB and resize it to the 512x512 model input.

    Returns:
        ``(tensor, image)``: a CHW tensor resized to 512x512, and the
        original-resolution RGB ``PIL.Image``.
    """
    # The context manager closes the file handle; convert() returns a new,
    # already-loaded image that stays valid after the source is closed.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    img_tensor = transforms.ToTensor()(img)
    # skimage works on HWC numpy arrays, so round-trip CHW -> HWC -> CHW.
    resized_img = skimage.transform.resize(
        img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32),
        (512, 512),
    )
    return torch.tensor(resized_img).permute(2, 0, 1), img


def save_plot(output_path: str):
    """Save the current matplotlib figure to ``output_path`` and close it."""
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
    print(f"Saved plot to {output_path}")
    plt.close()


def get_colors():
    """Return a new list of predefined hex colour strings used for drawing."""
    return [
        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a', '#d62728', '#ff9896',
        '#9467bd', '#c5b0d5', '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f', '#c7c7c7',
        '#bcbd22', '#dbdb8d', '#17becf', '#9edae5', '#8dd3c7', '#ffffb3', '#bebada', '#fb8072',
        '#80b1d3', '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5', '#ffed6f', '#8da0cb',
        '#e78ac3', '#e5c494', '#b3b3b3', '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
        '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173', '#ad494a', '#843c39', '#dd8452',
        '#f7f7f7', '#cccccc', '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6', '#99d8c9',
        '#66c2a4', '#2ca25f', '#008d4c', '#005a32', '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5',
        '#7bccc4', '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4', '#bfd3e6', '#9ebcda',
        '#8c96c6', '#8c6bb4', '#88419d', '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
        '#bfbfbf', '#969696', '#737373', '#525252', '#252525', '#000000', '#ffffff', '#ffeda0',
        '#fed976', '#feb24c', '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026', '#ff6f61',
        '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
    ]


def process_box(box, lines, scores):
    """For each bounding box, pick the longest predicted segment lying fully inside it.

    Args:
        box: iterable of boxes; each box ``i`` is indexed as ``i[0]..i[3]``.
            An endpoint ``p`` is inside when ``i[0] <= p[1] <= i[2]`` and
            ``i[1] <= p[0] <= i[3]`` (presumably boxes are (x0, y0, x1, y1)
            and endpoints are (y, x) - TODO confirm).
        lines: tensor indexed as ``lines[0, j]`` -> one segment given as two
            endpoints. Coordinates are used as-is (no rescaling).
        scores: unused; kept for interface compatibility.

    Returns:
        ``(valid_lines, valid_scores)`` with one entry per box. The score is the
        segment's Euclidean length. A box with no qualifying segment gets
        ``[[0.0, 0.0], [0.0, 0.0]]`` and ``0.0``.
    """
    # Convert once instead of once per (box, line) pair.
    candidates = lines[0].cpu().numpy()
    valid_lines = []
    valid_scores = []
    for i in box:
        best_line = None
        max_length = 0.0
        for line_j in candidates:
            if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
                    line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
                    line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
                    line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
                length = np.linalg.norm(line_j[0] - line_j[1])
                if length > max_length:
                    best_line = line_j
                    max_length = length
        if best_line is not None:
            valid_lines.append(best_line)
            valid_scores.append(max_length)  # segment length serves as the score
        else:
            # Placeholder for "no segment found in this box".
            valid_lines.append([[0.0, 0.0], [0.0, 0.0]])
            valid_scores.append(0.0)
    return valid_lines, valid_scores
def _best_line_in_box(box, line, score):
    """Return ``(segment, score)`` for the highest-scoring segment fully inside ``box``.

    A segment qualifies when both endpoints ``p`` satisfy
    ``box[0] <= p[1] <= box[2]`` and ``box[1] <= p[0] <= box[3]``.
    Only strictly positive scores can win. If nothing qualifies, the result is
    ``([[0.0, 0.0], [0.0, 0.0]], 0.0)``.
    """
    best = [[0.0, 0.0], [0.0, 0.0]]
    best_score = 0.0
    for seg, seg_score in zip(line, score):
        inside = (seg[0][1] >= box[0] and seg[1][1] >= box[0] and
                  seg[0][1] <= box[2] and seg[1][1] <= box[2] and
                  seg[0][0] >= box[1] and seg[1][0] >= box[1] and
                  seg[0][0] <= box[3] and seg[1][0] <= box[3])
        if inside and seg_score > best_score:
            best = seg
            best_score = seg_score
    return best, best_score


def _longest_segment_from_points(box, points):
    """Return the longest segment joining two junction points that lie inside ``box``.

    Points are unpacked as ``(x, y)`` pairs and kept when
    ``box[0] <= y <= box[2]`` and ``box[1] <= x <= box[3]``. This is the same
    axis convention as the segment test in ``_best_line_in_box``. Returns the
    all-zero placeholder when fewer than two distinct points are inside.
    """
    box_points = [[x, y] for x, y in points
                  if box[0] <= y <= box[2] and box[1] <= x <= box[3]]
    longest = [[0.0, 0.0], [0.0, 0.0]]
    max_distance = 0.0
    # Visit each unordered pair once; identical points have distance 0 and never win.
    for k, p1 in enumerate(box_points):
        for p2 in box_points[k + 1:]:
            distance = np.linalg.norm(np.array(p1) - np.array(p2))
            if distance > max_distance:
                max_distance = distance
                longest = [p1, p2]
    return longest


def box_line_optimized_parallel(imgs, pred, length=False):
    """Attach a best-matching line segment to every detected box.

    Args:
        imgs: CHW image tensor. Only its spatial size is used, to derive the
            merge threshold passed to ``postprocess``.
        pred: model output list. ``pred[-1]['wires']`` must provide ``lines``,
            ``score`` and ``juncs``. Line/junction coordinates are rescaled by
            512/128 here. The box dicts (with a ``'boxes'`` key) are taken from
            ``pred[:-2]``.
            NOTE(review): predict() treats ``pred[1]`` as a feature map, which
            is why the last two entries are skipped; confirm the layout.
        length: unused; kept for backward compatibility.

    Returns:
        ``pred``, modified in place. Each box dict gains ``'line'`` (one 2x2
        segment per box) and ``'line_score'`` (one score per box).

        - Boxes with no segment inside get the highest-scoring fallback: the
          longest segment between junctions inside the box, with score 0.0.
        - Failing that, they get an all-zero segment.
    """
    wires = pred[-1]['wires']
    line_data = wires['lines'][0].cpu().numpy() / 128 * 512
    line_scores = wires['score'].cpu().numpy()[0]
    points = wires['juncs'].cpu().numpy()[0] / 128 * 512
    box_entries = pred[:-2]

    if np.all(line_data == 0.0):
        # No lines predicted: give every box a zero segment and score. This
        # keeps the same per-box shapes as the normal branch, so downstream
        # zip() over boxes/lines/scores works.
        for entry in box_entries:
            n_boxes = len(entry['boxes'])
            entry['line'] = torch.zeros((n_boxes, 2, 2))
            entry['line_score'] = torch.zeros(n_boxes)
        return pred

    # imgs is CHW, so shape[1] / shape[2] are height / width.
    diag = (imgs.shape[1] ** 2 + imgs.shape[2] ** 2) ** 0.5
    line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)

    for entry in box_entries:
        line_ = []
        score_ = []
        for box in entry['boxes']:
            seg, seg_score = _best_line_in_box(box, line, score)
            if seg_score == 0.0:
                # No segment inside the box: fall back to the longest segment
                # between predicted junctions inside it. The score stays 0.0.
                seg = _longest_segment_from_points(box, points)
            line_.append(seg)
            score_.append(seg_score)
        entry['line'] = torch.tensor(line_)
        entry['line_score'] = torch.tensor(score_)
    return pred


def show_predict1(imgs, pred, t_start):
    """Draw the boxes of ``pred[0]`` with their matched lines; save and show the figure.

    Boxes whose line is the all-zero placeholder are skipped. The figure is
    written to ``temp_result.png``, shown, then closed. Prints the time elapsed
    since ``t_start``.
    """
    colors = get_colors()
    im = imgs.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for imshow
    det = pred[0]
    boxes = det['boxes'].cpu().numpy()
    box_scores = det['scores'].cpu().numpy()
    lines = det['line'].cpu().numpy()
    line_scores = det['line_score'].cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    empty_line = np.array([[0.0, 0.0], [0.0, 0.0]])
    idx = 0
    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
        # Skip boxes that contain no line.
        if np.array_equal(line, empty_line):
            continue
        x0, y0, x1, y1 = box
        a, b = line
        if box_score >= 0 or line_score >= 0:
            # Wrap around so more boxes than colours do not raise IndexError.
            color = colors[idx % len(colors)]
            ax.add_patch(
                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
            # Endpoints are stored (row, col), so plot as (x=col, y=row).
            ax.scatter(a[1], a[0], c='#871F78', s=10)
            ax.scatter(b[1], b[0], c='#871F78', s=10)
            ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
            idx += 1
    t_end = time.time()
    print(f'predict used:{t_end - t_start}')
    plt.savefig("temp_result.png")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across calls


_DEFAULT_MODEL_PATH = r"\\192.168.50.222\share\lm\weight\20250425_112601\weights\best.pth"


def predict(image_path, model_path=_DEFAULT_MODEL_PATH):
    """Run the model on one image and display channel 2 of ``predictions[1][0]`` (debug helper).

    Args:
        image_path: path of the image to run inference on.
        model_path: checkpoint to load. It is reloaded on every call.

    Returns:
        None. The intended post-processing is kept below as a commented draft.
    """
    model = load_model(model_path)
    img_tensor, _ = preprocess_image(image_path)
    print(f'img shape:{img_tensor.shape}')

    with torch.no_grad():
        predictions = model([img_tensor.to(device)])
    # predictions[1] is the feature map (original note: shape [1,256,128,128]).
    print(f'predictions[1][0]:{predictions[1][0].shape}')
    plt.imshow(predictions[1][0][2].cpu())
    plt.show()

    # NOTE(review): the draft below used to be a dead string literal (never
    # executed). It is kept for reference only; predict() returns None.
    # start_time1 = time.time()
    # show_line(img_tensor.permute(1, 2, 0).cpu().numpy(), predictions, start_time1)
    # show_box(img_tensor.permute(1, 2, 0).cpu().numpy(), predictions)
    # H = predictions[-1]['wires']
    # lines = H["lines"][0].cpu().numpy() / 128 * np.array([2112, 1328])
    # scores = H["score"][0].cpu().numpy()
    # for i in range(1, len(lines)):
    #     if (lines[i] == lines[0]).all():
    #         lines = lines[:i]
    #         scores = scores[:i]
    #         break
    # # postprocess lines to remove overlapped lines
    # diag = (512 ** 2 + 512 ** 2) ** 0.5
    # nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
    # print(f'线段 len:{len(nlines)}')
    # formatted_lines = []
    # for line in nlines:
    #     if (line == [[0.0, 0.0], [0.0, 0.0]]).all():
    #         continue
    #     start_point = np.array([line[0][0], line[0][1]])
    #     end_point = np.array([line[1][0], line[1][1]])
    #     formatted_lines.append([start_point, end_point])
    # formatted_lines = np.array(formatted_lines)
    # print(f"Final formatted_lines shape: {formatted_lines.shape}")
    # print(f"Sample formatted line: {formatted_lines[0] if len(formatted_lines) > 0 else 'No lines'}")
    # result = [formatted_lines]
    # print(f"Final result type: {type(result)}")
    # print(f"Final result[0] shape: {result[0].shape}")
    # return result


def show_line(im, predictions, start_time1):
    """Draw the post-processed wire lines over ``im`` and save to ``temp_line.png``.

    If every predicted line coordinate is zero, an empty figure is saved
    instead. Segments with a negative score are skipped. Prints the time
    elapsed since ``start_time1``.
    """
    wires = predictions[-1]['wires']
    lines = wires['lines'][0].cpu().numpy() / 128 * 512
    line_scores = wires['score'].cpu().numpy()[0]

    fig, ax = plt.subplots(figsize=(10, 10))
    if not np.all(lines == 0.0):
        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
        nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
        ax.imshow(im)
        for (a, b), s in zip(nlines, nscores):
            if s < 0:
                continue
            # Endpoints are (row, col); plot as (x=col, y=row).
            ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
            ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
    t_end = time.time()
    plt.savefig("temp_line.png")
    plt.close(fig)  # avoid accumulating open figures across calls
    print(f'show line time:{t_end-start_time1}')
def show_box(im, predictions):
    """Draw the boxes of ``predictions[0]`` over ``im`` and save to ``temp_box.png``.

    Boxes with a negative score are skipped. Each box gets a colour from
    ``get_colors()`` chosen by its index.
    """
    boxes = predictions[0]['boxes'].cpu().numpy()
    box_scores = predictions[0]['scores'].cpu().numpy()
    colors = get_colors()

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for idx, (box, score) in enumerate(zip(boxes, box_scores)):
        if score < 0:
            continue
        x0, y0, x1, y1 = box
        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                                   edgecolor=colors[idx % len(colors)], linewidth=1))
    plt.savefig("temp_box.png")
    plt.close(fig)  # avoid accumulating open figures across calls


def show_predict(im, filtered_pred, t_start):
    """Draw the boxes and matched lines of ``filtered_pred[0]`` and save the figure.

    Lines are passed through ``postprocess`` before drawing. A box is not drawn
    when any of these hold:

    - its box score or its post-processed line score is negative;
    - its line is missing;
    - its line is the all-zero "no line found" placeholder.

    Prints the time elapsed since ``t_start``.

    Returns:
        The saved image path, ``"temp_result.png"``.
    """
    colors = get_colors()
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)

    pred = filtered_pred[0]
    boxes = pred['boxes'].cpu().numpy()
    box_scores = pred['scores'].cpu().numpy()
    lines = pred['line'].cpu().numpy()
    line_scores = pred['line_score'].cpu().numpy()
    print("Boxes:", pred['boxes'])
    print("Lines:", pred['line'])
    print("Line scores:", pred['line_score'])

    if not np.all(lines == 0.0):
        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
        nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
        # NOTE(review): boxes are paired with postprocess output by position.
        # If postprocess drops or reorders lines, the pairing shifts; confirm.
        for box_idx, (box, line, box_score, line_score) in enumerate(
                zip(boxes, nlines, box_scores, nscores)):
            if box_score < 0.0 or line_score < 0.0:
                continue
            # Skip boxes without a line. A missing line is stored as an all-zero
            # segment, which the None/empty checks alone never catch.
            if line is None or len(line) == 0 or not np.any(line):
                continue
            x0, y0, x1, y1 = box
            a, b = line
            color = colors[box_idx % len(colors)]  # one colour per box
            ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
            # Endpoints are (row, col); plot as (x=col, y=row).
            ax.scatter(a[1], a[0], c=color, s=10)
            ax.scatter(b[1], b[0], c=color, s=10)
            ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)

    t_end = time.time()
    print(f'show_predict used: {t_end - t_start:.2f} seconds')
    output_path = "temp_result.png"
    save_plot(output_path)  # also closes the current figure
    return output_path


if __name__ == "__main__":
    lines = predict(r"C:\Users\m2337\Desktop\p\140502.png")
    print(f'lines:{lines}')