| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457 |
- # from fastapi import FastAPI, File, UploadFile, HTTPException
- # from fastapi.responses import FileResponse
- # from fastapi.staticfiles import StaticFiles
- import os
- import torch
- import numpy as np
- from PIL import Image
- import skimage.io
- import skimage.color
- from torchvision import transforms
- import shutil
- import matplotlib.pyplot as plt
- from models.line_net.line_net import linenet_resnet50_fpn
- from models.wirenet.postprocess import postprocess
- from rtree import index
- import time
- import multiprocessing as mp
- # from code.ubuntu2204.VisionWeldRobotMK.multivisionmodels.models.line_net.boxline import show_box
# Use 'spawn' as the multiprocessing start method; force=True overrides any
# start method that was already configured. NOTE(review): runs at import time.
mp.set_start_method('spawn', force=True)

# Inference device: CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def load_model(model_path):
    """Build the line-detection network and load trained weights into it.

    Args:
        model_path: Path of a checkpoint readable by ``torch.load``. The loaded
            object must contain a ``'model_state_dict'`` entry.

    Returns:
        The ``linenet_resnet50_fpn`` model, moved to the module-level
        ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    # Validate the path before constructing the network, so a bad path fails
    # fast without first allocating the whole model on the target device.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No saved model found at {model_path}")
    model = linenet_resnet50_fpn().to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from {model_path}")
    model.eval()
    return model
def preprocess_image(image_path, size=(512, 512)):
    """Load an image file and build the network input tensor.

    Args:
        image_path: Path of the image to read. It is converted to RGB.
        size: Target ``(height, width)`` passed to ``skimage.transform.resize``.
            Defaults to ``(512, 512)``.

    Returns:
        A tuple ``(tensor, image)``:
        - ``tensor`` is the ``ToTensor`` output resized to ``size``, laid out
          channels-first as ``(3, height, width)``.
        - ``image`` is the original-resolution RGB ``PIL.Image``.
    """
    # skimage.transform is not among the module-level imports; import it here
    # so resize() is available regardless of skimage's submodule loading.
    import skimage.transform

    # Open in a context manager so the file handle is released. convert()
    # returns an independent copy, so `img` stays valid after the file closes.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    img_tensor = transforms.ToTensor()(img)
    # skimage works on channels-last arrays, so permute (C, H, W) -> (H, W, C).
    resized_img = skimage.transform.resize(
        img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), size
    )
    # Back to channels-first for the model.
    return torch.tensor(resized_img).permute(2, 0, 1), img
def save_plot(output_path: str):
    """Write the current matplotlib figure to ``output_path``, then close it.

    The figure is saved with a tight bounding box and zero padding, and the
    destination path is printed.
    """
    save_options = {"bbox_inches": 'tight', "pad_inches": 0}
    plt.savefig(output_path, **save_options)
    print(f"Saved plot to {output_path}")
    plt.close()
# Fixed drawing palette of 100 hex colour strings (some entries repeat).
_PALETTE = (
    '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
    '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
    '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
    '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
    '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
    '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
    '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
    '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
    '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
    '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
    '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
    '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
    '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
    '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
    '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
    '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
    '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
    '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
    '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
    '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072',
)


def get_colors():
    """Return the predefined drawing palette as a new list.

    A fresh list is built on every call, so callers may modify the result
    without affecting later calls.
    """
    return list(_PALETTE)
def process_box(box, lines, scores):
    """For each bounding box, pick the longest line segment lying fully inside it.

    Args:
        box: Iterable of boxes ``b``, each indexable as ``b[0]..b[3]``.
            A segment endpoint ``p`` counts as inside when
            ``b[0] <= p[1] <= b[2]`` and ``b[1] <= p[0] <= b[3]``.
        lines: Tensor where ``lines[0, j]`` is one segment given as two
            endpoints. Only ``lines[0]`` is read.
        scores: Unused; kept for interface compatibility.

    Returns:
        ``(valid_lines, valid_scores)``, with one entry per box.
        - ``valid_lines[k]`` is the longest fully contained segment (a NumPy
          array), or the placeholder ``[[0.0, 0.0], [0.0, 0.0]]`` if none fits.
        - ``valid_scores[k]`` is that segment's Euclidean length, or ``0.0``.
        Ties keep the first segment; zero-length segments are never selected.
    """
    # Copy the candidate segments to host memory once, instead of once per
    # (box, segment) pair. lines_np[j] equals lines[0, j].cpu().numpy().
    lines_np = lines[0].cpu().numpy()

    valid_lines = []   # best segment per box
    valid_scores = []  # its length, used as the score
    for i in box:
        best_line = None
        max_length = 0.0
        for line_j in lines_np:
            # Both endpoints must lie inside the box.
            if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
                    line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
                    line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
                    line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
                length = np.linalg.norm(line_j[0] - line_j[1])
                if length > max_length:
                    best_line = line_j
                    max_length = length
        if best_line is not None:
            valid_lines.append(best_line)
            valid_scores.append(max_length)
        else:
            # No segment fits: emit a zero placeholder so outputs stay aligned
            # with the input boxes.
            valid_lines.append([[0.0, 0.0], [0.0, 0.0]])
            valid_scores.append(0.0)
    return valid_lines, valid_scores
- # def box_line_optimized_parallel(imgs, pred): # ĬÈÏÖÃÐŶÈ
- # im = imgs.permute(1, 2, 0).cpu().numpy()
- # line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
- # line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
- #
- # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
- # line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
- # for idx, box_ in enumerate(pred[0:-1]):
- # box = box_['boxes']
- #
- # line_ = []
- # score_ = []
- #
- # for i in box:
- # score_max = 0.0
- # tmp = [[0.0, 0.0], [0.0, 0.0]]
- #
- # for j in range(len(line)):
- # if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
- # line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
- # line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
- # line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
- #
- # if score[j] > score_max:
- # tmp = line[j]
- # score_max = score[j]
- # line_.append(tmp)
- # score_.append(score_max)
- # processed_list = torch.tensor(line_)
- # pred[idx]['line'] = processed_list
- #
- # processed_s_list = torch.tensor(score_)
- # pred[idx]['line_score'] = processed_s_list
- # del pred[-1]
- # return pred
def _best_scored_line_in_box(b, line, score):
    """Return ``(segment, score)`` of the highest-scoring segment lying inside box ``b``.

    A segment ``s`` is inside when both endpoints satisfy
    ``b[0] <= s[k][1] <= b[2]`` and ``b[1] <= s[k][0] <= b[3]``. Only scores
    strictly above 0 are accepted, and the first maximum wins. When nothing
    qualifies, returns the zero placeholder and ``0.0``.
    """
    best_seg = [[0.0, 0.0], [0.0, 0.0]]
    best_score = 0.0
    for j in range(len(line)):
        seg = line[j]
        if (b[0] <= seg[0][1] <= b[2] and b[0] <= seg[1][1] <= b[2]
                and b[1] <= seg[0][0] <= b[3] and b[1] <= seg[1][0] <= b[3]):
            if score[j] > best_score:
                best_seg = seg
                best_score = score[j]
    return best_seg, best_score


def _longest_junction_segment(b, points):
    """Return the longest segment joining two junction points inside box ``b``.

    A junction ``(x, y)`` is inside when ``b[0] <= y <= b[2]`` and
    ``b[1] <= x <= b[3]``. Returns the zero placeholder when fewer than two
    junctions are inside or all of them coincide.
    """
    inside = [[x, y] for x, y in points if b[0] <= y <= b[2] and b[1] <= x <= b[3]]
    best = [[0.0, 0.0], [0.0, 0.0]]
    best_dist = 0.0
    # Visit each unordered pair once. With a strict '>' the first maximal pair
    # in index order wins, same as scanning all ordered pairs.
    for a in range(len(inside)):
        for c in range(a + 1, len(inside)):
            p1, p2 = inside[a], inside[c]
            dist = np.linalg.norm(np.array(p1) - np.array(p2))
            if dist > best_dist:
                best_dist = dist
                best = [p1, p2]
    return best


def box_line_optimized_parallel(imgs, pred, length=False):  # default: select by line confidence
    """Attach one line segment and its score to every detection box, in place.

    Args:
        imgs: Image tensor. Only ``imgs.shape[1]`` and ``imgs.shape[2]`` are
            read, as height/width for the de-duplication tolerance. Presumably
            ``(C, H, W)`` — confirm with callers.
        pred: Model output list.
            - ``pred[-1]['wires']`` holds ``'lines'``, ``'score'`` and
              ``'juncs'`` tensors. Coordinates are rescaled by 512 / 128 below.
            - Every other entry that is a dict with a ``'boxes'`` tensor gets
              ``'line'`` (shape ``(N, 2, 2)``) and ``'line_score'`` (shape
              ``(N,)``) set, aligned with its N boxes.
            - Other entries (e.g. a feature map) are left untouched.
        length: Currently unused; kept for interface compatibility.

    Returns:
        ``pred``, modified in place.

    For each box, the highest-scoring post-processed line fully inside it is
    chosen. If none qualifies, the longest segment between two junction
    points inside the box is used with score 0.0. If the model produced no
    lines at all (all zeros), every box gets a zero line and a zero score.
    """
    wires = pred[-1]['wires']
    line_data = wires['lines'][0].cpu().numpy() / 128 * 512
    line_scores = wires['score'].cpu().numpy()[0]
    points = wires['juncs'].cpu().numpy()[0] / 128 * 512

    has_lines = not np.all(line_data == 0.0)
    if has_lines:
        # Height/width come from the tensor shape; there is no need to copy
        # the image to host memory just to read its size.
        diag = (imgs.shape[1] ** 2 + imgs.shape[2] ** 2) ** 0.5
        line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)

    for det in pred[:-1]:
        # Only detection dicts carry boxes. Writing 'line' into a non-dict
        # entry such as a feature-map tensor would raise.
        if not isinstance(det, dict) or 'boxes' not in det:
            continue
        boxes = det['boxes']

        if not has_lines:
            # Keep one (zero) line and score per box, so downstream code can
            # zip them with the boxes.
            n_boxes = len(boxes)
            det['line'] = torch.zeros((n_boxes, 2, 2))
            det['line_score'] = torch.zeros(n_boxes)
            continue

        # Move the boxes to the host once, instead of comparing element-wise
        # against (possibly GPU) tensors inside the inner loops.
        boxes_np = boxes.detach().cpu().numpy()
        seg_list = []
        score_list = []
        for b in boxes_np:
            seg, seg_score = _best_scored_line_in_box(b, line, score)
            if seg_score == 0.0:
                # No scored line inside the box: fall back to the longest
                # junction-to-junction segment. Its score stays 0.0.
                seg = _longest_junction_segment(b, points)
            seg_list.append(seg)
            score_list.append(seg_score)
        det['line'] = torch.tensor(seg_list)
        det['line_score'] = torch.tensor(score_list)
    return pred
def show_predict1(imgs, pred, t_start):
    """Draw each box of ``pred[0]`` with its matched line, save and show the figure.

    Args:
        imgs: Image tensor, displayed after ``permute(1, 2, 0)``. Presumably
            ``(C, H, W)`` — confirm with callers.
        pred: Prediction list. ``pred[0]`` must hold ``'boxes'``, ``'scores'``,
            ``'line'`` and ``'line_score'`` tensors, with lines aligned to
            boxes.
        t_start: ``time.time()`` timestamp used to report the elapsed time.

    Boxes whose line is the all-zero placeholder are skipped. Colours cycle
    through ``get_colors()`` in drawing order. The figure is saved to
    ``temp_result.png`` and then displayed with ``plt.show()``.
    """
    colors = get_colors()
    # detach/cpu so tensors produced on the GPU can be handed to matplotlib.
    im = imgs.permute(1, 2, 0).detach().cpu().numpy()
    boxes = pred[0]['boxes'].cpu().numpy()
    box_scores = pred[0]['scores'].cpu().numpy()
    lines = pred[0]['line'].cpu().numpy()
    line_scores = pred[0]['line_score'].cpu().numpy()

    # Visualise the predictions.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    placeholder = np.array([[0.0, 0.0], [0.0, 0.0]])
    drawn = 0
    for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
        # Skip boxes without a matched line.
        if np.array_equal(line, placeholder):
            continue
        x0, y0, x1, y1 = box
        a, b = line
        if box_score >= 0 or line_score >= 0:
            # Wrap around so more than len(colors) boxes do not raise IndexError.
            color = colors[drawn % len(colors)]
            ax.add_patch(
                plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
            # Endpoints are stored as (row, col), so plot them as (col, row).
            ax.scatter(a[1], a[0], c='#871F78', s=10)
            ax.scatter(b[1], b[0], c='#871F78', s=10)
            ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
            drawn += 1
    t_end = time.time()
    print(f'predict used:{t_end - t_start}')
    plt.savefig("temp_result.png")
    plt.show()
def predict(image_path,
            model_path=r"\\192.168.50.222\share\lm\weight\20250425_112601\weights\best.pth"):
    """Run the line-detection model on one image and display one feature-map channel.

    Args:
        image_path: Image file to run inference on.
        model_path: Checkpoint passed to ``load_model``. Defaults to the
            network-share path that was previously hard-coded.

    Returns:
        None. NOTE(review): the post-processing that built and returned the
        formatted line array is disabled (kept in the string literal at the
        end of this function), so callers currently receive ``None``.
    """
    model = load_model(model_path)
    img_tensor, _ = preprocess_image(image_path)
    print(f'img shape:{img_tensor.shape}')
    # Model inference.
    with torch.no_grad():
        predictions = model([img_tensor.to(device)])
    # predictions[1] is presumably the feature map ([1, 256, 128, 128] per the
    # original author's note) — confirm. Channel 2 is shown for inspection.
    print(f'predictions[1][0]:{predictions[1][0].shape}')
    plt.imshow(predictions[1][0][2].cpu())
    plt.show()
    # NOTE(review): disabled code kept as an inert string literal; it is never
    # executed. Re-enable (and return its result) to get line output again.
    '''
    start_time1 = time.time()
    show_line(img_tensor.permute(1, 2, 0).cpu().numpy(), predictions, start_time1)
    show_box(img_tensor.permute(1, 2, 0).cpu().numpy(), predictions)
    H = predictions[-1]['wires']
    lines = H["lines"][0].cpu().numpy() / 128 * np.array([2112, 1328])
    scores = H["score"][0].cpu().numpy()
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            lines = lines[:i]
            scores = scores[:i]
            break
    # postprocess lines to remove overlapped lines
    diag = (512 ** 2 + 512 ** 2) ** 0.5
    nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
    # lines = filtered_pred[0]['line'].cpu().numpy() / 512 * np.array([2112, 1328])
    print(f'线段 len:{len(nlines)}')
    # print(f"Initial lines shape: {lines.shape}")
    # print(f"Initial lines data type: {lines.dtype}")
    formatted_lines = []
    for line in nlines:
        if (line == [[0.0, 0.0], [0.0, 0.0]]).all():
            continue
        #line=[[500.0, 500.0], [650.0, 650.0]]
        start_point = np.array([line[0][0], line[0][1]])
        end_point = np.array([line[1][0], line[1][1]])
        formatted_lines.append([start_point, end_point])
    formatted_lines = np.array(formatted_lines)
    print(f"Final formatted_lines shape: {formatted_lines.shape}")
    print(f"Sample formatted line: {formatted_lines[0] if len(formatted_lines) > 0 else 'No lines'}")
    # È·±£·µ»ØµÄÊÇÈýάÊý×飺[lines_array]
    result = [formatted_lines]
    print(f"Final result type: {type(result)}")
    print(f"Final result[0] shape: {result[0].shape}")
    return result
    '''
def show_line(im, predictions, start_time1):
    """Draw the post-processed wire lines over ``im`` and save to ``temp_line.png``.

    Args:
        im: Image passed to ``ax.imshow``. Its first two dimensions are used
            as height/width for the de-duplication tolerance.
        predictions: Model output whose last entry holds
            ``['wires']['lines']`` and ``['wires']['score']`` tensors.
        start_time1: ``time.time()`` timestamp used to report the elapsed time.

    Lines with a negative score are skipped. The image is always drawn, even
    when no line was detected. The figure is closed after saving.
    """
    # Rescale line coordinates by 512 / 128, as elsewhere in this module.
    lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
    line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    # An all-zero line tensor means nothing was detected, so there is
    # nothing to post-process or draw.
    if not np.all(lines == 0.0):
        diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
        nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
        for (a, b), s in zip(nlines, nscores):
            if s < 0:
                continue
            # Endpoints are stored as (row, col), so plot them as (col, row).
            ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
            ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
    t_end = time.time()
    plt.savefig("temp_line.png")
    # Release the figure; otherwise every call leaves an open figure behind.
    plt.close(fig)
    print(f'show line time:{t_end-start_time1}')
def show_box(im, predictions):
    """Draw the boxes of ``predictions[0]`` over ``im`` and save to ``temp_box.png``.

    Args:
        im: Image passed to ``ax.imshow``.
        predictions: Model output whose first entry holds ``'boxes'`` and
            ``'scores'`` tensors.

    Boxes with a negative score are skipped. Colours cycle through
    ``get_colors()`` by box index. The figure is closed after saving.
    """
    boxes = predictions[0]['boxes'].cpu().numpy()
    box_scores = predictions[0]['scores'].cpu().numpy()
    colors = get_colors()
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for idx, (box, score) in enumerate(zip(boxes, box_scores)):
        if score < 0:
            continue
        x0, y0, x1, y1 = box
        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                                   edgecolor=colors[idx % len(colors)], linewidth=1))
    plt.savefig("temp_box.png")
    # Release the figure; otherwise every call leaves an open figure behind.
    plt.close(fig)
def show_predict(im, filtered_pred, t_start):
    """Draw every box of ``filtered_pred[0]`` with its matched line and save the plot.

    Args:
        im: Image passed to ``ax.imshow``.
        filtered_pred: Prediction list. ``filtered_pred[0]`` has ``'boxes'``,
            ``'scores'``, ``'line'`` and ``'line_score'`` tensors, where
            ``line[k]`` and ``line_score[k]`` belong to ``boxes[k]``.
        t_start: ``time.time()`` timestamp used to report the elapsed time.

    Returns:
        ``"temp_result.png"``, the path the figure was saved to via
        ``save_plot``.
    """
    colors = get_colors()
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    pred = filtered_pred[0]
    boxes = pred['boxes'].cpu().numpy()
    box_scores = pred['scores'].cpu().numpy()
    lines = pred['line'].cpu().numpy()
    line_scores = pred['line_score'].cpu().numpy()
    print("Boxes:", pred['boxes'])
    print("Lines:", pred['line'])
    print("Line scores:", pred['line_score'])

    # Nothing to draw when no box received a line.
    if not np.all(lines == 0.0):
        # lines[k] / line_scores[k] are aligned with boxes[k]. Do not run them
        # through postprocess() here: it removes overlapping lines, so zipping
        # its output with `boxes` would pair lines with the wrong boxes.
        for box_idx, (box, line, box_score, line_score) in enumerate(
                zip(boxes, lines, box_scores, line_scores)):
            if box_score < 0.0 or line_score < 0.0:
                continue
            # All-zero placeholder: no line was matched to this box.
            if not np.any(line):
                continue
            x0, y0, x1, y1 = box
            a, b = line
            # One colour per box, cycling through the palette.
            color = colors[box_idx % len(colors)]
            ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
            # Endpoints are stored as (row, col), so plot them as (col, row).
            ax.scatter(a[1], a[0], c=color, s=10)
            ax.scatter(b[1], b[0], c=color, s=10)
            ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
    t_end = time.time()
    print(f'show_predict used: {t_end - t_start:.2f} seconds')
    output_path = "temp_result.png"
    save_plot(output_path)
    return output_path
if __name__ == "__main__":
    # Ad-hoc local run on a developer's sample image.
    sample_image = r"C:\Users\m2337\Desktop\p\140502.png"
    lines = predict(sample_image)
    print(f'lines:{lines}')
|