123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274 |
# zjf's version scores each candidate line segment by its length; this version scores it by the line's maximum confidence instead.
- # from fastapi import FastAPI, File, UploadFile, HTTPException
- # from fastapi.responses import FileResponse
- # from fastapi.staticfiles import StaticFiles
- import os
- import torch
- import numpy as np
- from PIL import Image
- import skimage.io
- import skimage.color
- from torchvision import transforms
- import shutil
- import matplotlib.pyplot as plt
- from models.line_detect.line_net import linenet_resnet50_fpn
- from models.line_detect.postprocess import postprocess
- from rtree import index
- import time
- import multiprocessing as mp
- # from fastapi.middleware.cors import CORSMiddleware
- # 初始化 FastAPI
- # app = FastAPI()
- # 添加 CORS 中间件
- # app.add_middleware(
- # CORSMiddleware,
- # allow_origins=["*"], # 允许所有源
- # allow_credentials=True,
- # allow_methods=["*"],
- # allow_headers=["*"],
- # )
# Worker processes are started with 'spawn' (forced, overriding any start
# method already set). NOTE(review): presumably chosen because the matching
# step hands torch tensors to a multiprocessing.Pool — confirm before changing.
mp.set_start_method('spawn', force=True)

# Run on CUDA when it is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def load_model(model_path):
    """Build the line-detection network and load trained weights into it.

    Args:
        model_path: Path to a checkpoint written with ``torch.save``; it must
            be a dict containing a ``'model_state_dict'`` entry.

    Returns:
        The model, moved to the module-level ``device`` and switched to
        ``eval()`` mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    # Check the path before constructing the network (and moving it to
    # ``device``) so a bad path fails fast without that wasted work.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No saved model found at {model_path}")

    model = linenet_resnet50_fpn().to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from {model_path}")
    model.eval()
    return model
def preprocess_image(image_path, size=(512, 512)):
    """Load an image from disk and turn it into a resized CHW float tensor.

    Args:
        image_path: Path of any PIL-readable image; it is converted to RGB.
        size: Target ``(height, width)`` passed to ``skimage.transform.resize``.
            Defaults to ``(512, 512)``, matching the 512-pixel scale used when
            line coordinates are rescaled elsewhere in this file.

    Returns:
        A ``torch.Tensor`` of shape ``(3, height, width)``. ``ToTensor`` scales
        the 8-bit RGB pixels to ``[0, 1]`` before resizing.
    """
    # Only skimage.io / skimage.color are imported at file level; import the
    # transform submodule explicitly instead of relying on it being loaded as
    # a side effect of those imports.
    import skimage.transform

    # Context manager closes the file handle once the pixels have been
    # converted (convert() returns an independent, fully loaded image).
    with Image.open(image_path) as src:
        img = src.convert("RGB")

    img_tensor = transforms.ToTensor()(img)
    # skimage works on HWC numpy arrays: convert, resize, then go back to CHW.
    hwc = img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32)
    resized_img = skimage.transform.resize(hwc, size)
    return torch.tensor(resized_img).permute(2, 0, 1)
def save_plot(output_path: str):
    """Write the active matplotlib figure to ``output_path`` and close it.

    The saved image is cropped tightly to its content with zero padding.
    """
    save_options = {'bbox_inches': 'tight', 'pad_inches': 0}
    plt.savefig(output_path, **save_options)
    print(f"Saved plot to {output_path}")
    plt.close()
# Fixed 100-entry hex colour palette. Entries are not unique (e.g. '#ccebc5'
# and '#969696' occur twice); order matters because the plotting helpers pick
# colours by index modulo the palette length.
_COLOR_PALETTE = (
    '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
    '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
    '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
    '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
    '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
    '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
    '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
    '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
    '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
    '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
    '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
    '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
    '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
    '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
    '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
    '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
    '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
    '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
    '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
    '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072',
)


def get_colors():
    """Return the fixed 100-colour hex palette as a new list.

    A fresh list is built on every call, so callers may mutate the result
    without affecting later calls.
    """
    return list(_COLOR_PALETTE)
def _to_numpy(value):
    """Return ``value`` as a NumPy array, copying torch tensors to host memory first."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def process_box(box, lines, scores):
    """Pick, for every box, the highest-confidence line segment lying fully inside it.

    Args:
        box: Sequence of boxes; each row is unpacked as ``(x0, y0, x1, y1)``,
            the same order ``show_box`` uses for drawing.
        lines: Segments indexed as ``lines[0, j]``; each segment is two
            ``(row, col)`` endpoints, i.e. ``(y, x)``. Coordinates are
            multiplied by 512 / 128 before being compared with the boxes.
            May be a torch tensor on any device or an array-like.
        scores: Confidence per segment, indexed like the second axis of ``lines``.

    Returns:
        ``(valid_lines, valid_scores)`` with one entry per box. For a box that
        fully encloses at least one segment with score > 0, the entry is that
        segment (a ``(2, 2)`` array in rescaled coordinates) and its score as a
        Python float; ties go to the lowest segment index. Otherwise the entry
        is ``[[0.0, 0.0], [0.0, 0.0]]`` with score ``0.0``.
    """
    # Convert once for all boxes instead of copying every segment to the host
    # separately inside a box x segment Python double loop.
    all_lines = _to_numpy(lines)[0] / 128 * 512
    all_scores = _to_numpy(scores).reshape(-1)[: len(all_lines)]
    rows = all_lines[:, :2, 0]  # y of both endpoints, shape (L, 2)
    cols = all_lines[:, :2, 1]  # x of both endpoints, shape (L, 2)
    # Only strictly positive scores can be selected (NaN compares False).
    positive = all_scores > 0.0

    valid_lines = []
    valid_scores = []
    for b in box:
        x0, y0, x1, y1 = b[0], b[1], b[2], b[3]
        # Both endpoints must lie inside the box (bounds inclusive).
        inside = (
            ((cols >= x0) & (cols <= x1)).all(axis=1)
            & ((rows >= y0) & (rows <= y1)).all(axis=1)
        )
        candidates = inside & positive
        if candidates.any():
            # argmax returns the first maximal index, so ties resolve to the
            # lowest segment index.
            j = int(np.argmax(np.where(candidates, all_scores, -np.inf)))
            valid_lines.append(all_lines[j].copy())
            valid_scores.append(float(all_scores[j]))
        else:
            valid_lines.append([[0.0, 0.0], [0.0, 0.0]])
            valid_scores.append(0.0)
    return valid_lines, valid_scores
def box_line_optimized_parallel(pred):
    """Match every set of detected boxes with line segments, in worker processes.

    Args:
        pred: Model output list. Every entry except the last is a detection
            dict with a ``'boxes'`` tensor; the last entry holds
            ``['wires']['lines']`` (shape ``[1, 2500, 2, 2]`` per the original
            comment) and ``['wires']['score']``.

    Returns:
        The detection dicts that got at least one matched entry, mutated in
        place with ``'line'`` and ``'line_score'`` tensors added.
    """
    # Move everything to host memory once, before any process boundary, so
    # process_box's .cpu() calls in the workers are no-ops rather than device
    # transfers from inside each spawned process.
    lines = pred[-1]['wires']['lines'].detach().cpu()  # shape [1, 2500, 2, 2]
    scores = pred[-1]['wires']['score'][0].detach().cpu()  # presumably shape [2500]
    boxes = [box_['boxes'].detach().cpu().numpy() for box_ in pred[:-1]]

    jobs = [(box, lines, scores) for box in boxes]
    if len(jobs) <= 1:
        # mp.Pool(processes=0) raises ValueError, and spawning a pool for a
        # single job only adds start-up cost, so run it in this process.
        results = [process_box(*job) for job in jobs]
    else:
        with mp.Pool(processes=min(mp.cpu_count(), len(jobs))) as pool:
            results = pool.starmap(process_box, jobs)

    # Attach the matches; results[i] belongs to pred[i].
    filtered_pred = []
    for entry, (valid_lines, valid_scores) in zip(pred, results):
        if valid_lines:
            entry['line'] = torch.tensor(valid_lines)
            entry['line_score'] = torch.tensor(valid_scores)
            filtered_pred.append(entry)
    return filtered_pred
def predict(image_path, model_path="/data/lm/resnet50_best_e280.pth"):
    """Run the full pipeline on one image and save a combined visualisation.

    Loads the model, preprocesses ``image_path``, runs inference, matches
    boxes to line segments, renders three figures (box + line, boxes only,
    lines only) and stitches them into ``combined_result.png`` in the current
    working directory.

    Args:
        image_path: Path of the input image.
        model_path: Checkpoint to load; defaults to the path previously
            hard-coded here. The model is reloaded on every call.

    Returns:
        Path of the combined image, ``"combined_result.png"``.
    """
    start_time = time.time()

    model = load_model(model_path)

    img_tensor = preprocess_image(image_path)
    # HWC numpy copy of the network input, used as the background of every plot.
    im = img_tensor.permute(1, 2, 0).cpu().numpy()

    with torch.no_grad():
        predictions = model([img_tensor.to(device)])

    t_start = time.time()
    filtered_pred = box_line_optimized_parallel(predictions)
    t_end = time.time()
    print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')

    # Each show_* helper prints the time elapsed since the timestamp it gets.
    # Passing the matching step's t_start to all three made every report
    # cumulative, so give each call its own start time.
    output_path_box = show_box(im, predictions, time.time())
    output_path_line = show_line(im, predictions, time.time())
    output_path_boxandline = show_predict(im, filtered_pred, time.time())

    combined_image_path = "combined_result.png"
    combine_images(
        [output_path_boxandline, output_path_box, output_path_line],
        titles=["Box and Line", "Box", "Line"],
        output_path=combined_image_path,
    )

    end_time = time.time()
    print(f'Total time: {end_time - start_time:.2f} seconds')
    return combined_image_path
def combine_images(image_paths: list, titles: list, output_path: str):
    """Place several saved images side by side in one figure and save it.

    Args:
        image_paths: Image files read with ``plt.imread``, drawn left to right.
        titles: Panel titles paired positionally with ``image_paths``
            (``zip`` stops at the shorter of the two lists).
        output_path: Where the combined figure is written.

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    n = len(image_paths)
    if n == 0:
        raise ValueError("combine_images needs at least one image path")

    # squeeze=False keeps ``axes`` 2-D even for a single panel; by default
    # plt.subplots(1, 1) returns a bare Axes that zip() cannot iterate.
    # Width scales with the panel count (15 in for 3 panels, as before).
    fig, axes = plt.subplots(1, n, figsize=(5 * n, 5), squeeze=False)
    for ax, img_path, title in zip(axes[0], image_paths, titles):
        ax.imshow(plt.imread(img_path))
        ax.set_title(title)
        ax.axis("off")
    fig.savefig(output_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)
def show_box(im, predictions, t_start):
    """Overlay the first prediction's bounding boxes on ``im`` and save the figure.

    Boxes scoring below 0.7 are skipped; each drawn box gets the palette
    colour at its index (cycling through the palette). The time elapsed since
    ``t_start`` is printed before saving.

    Returns:
        Path of the saved figure (``"temp_result_box.png"``).
    """
    detections = predictions[0]
    box_array = detections['boxes'].cpu().numpy()
    score_array = detections['scores'].cpu().numpy()
    palette = get_colors()

    _, axis = plt.subplots(figsize=(10, 10))
    axis.imshow(im)
    for n, (bbox, confidence) in enumerate(zip(box_array, score_array)):
        if confidence < 0.7:
            continue
        left, top, right, bottom = bbox
        outline = plt.Rectangle(
            (left, top), right - left, bottom - top,
            fill=False, edgecolor=palette[n % len(palette)], linewidth=1,
        )
        axis.add_patch(outline)

    print(f'show_box used: {time.time() - t_start:.2f} seconds')
    output_path = "temp_result_box.png"
    save_plot(output_path)
    return output_path
def show_line(im, predictions, t_start):
    """Overlay post-processed line segments on ``im`` and save the figure.

    Segment endpoints come from ``predictions[-1]['wires']`` and are rescaled
    by 512 / 128 before ``postprocess`` is applied, with a distance threshold
    of 1% of the image diagonal. Segments scoring below 0.9 are skipped; the
    rest are drawn in red with small ``'#871F78'`` endpoint markers. The time
    elapsed since ``t_start`` is printed before saving.

    Returns:
        Path of the saved figure (``"temp_result_line.png"``).
    """
    wires = predictions[-1]['wires']
    segments = wires['lines'][0].cpu().numpy() / 128 * 512
    confidences = wires['score'].cpu().numpy()[0]
    height, width = im.shape[0], im.shape[1]
    diagonal = (height ** 2 + width ** 2) ** 0.5
    kept_lines, kept_scores = postprocess(segments, confidences, diagonal * 0.01, 0, False)

    _, axis = plt.subplots(figsize=(10, 10))
    axis.imshow(im)
    for (start, end), confidence in zip(kept_lines, kept_scores):
        if confidence < 0.9:
            continue
        # Endpoints are (row, col): column goes on the x axis, row on the y axis.
        xs = [start[1], end[1]]
        ys = [start[0], end[0]]
        axis.scatter(xs, ys, c='#871F78', s=2)
        axis.plot(xs, ys, c='red', linewidth=1)

    print(f'show_line used: {time.time() - t_start:.2f} seconds')
    output_path = "temp_result_line.png"
    save_plot(output_path)
    return output_path
def show_predict(im, filtered_pred, t_start):
    """Draw each matched box together with its line segment and save the figure.

    Args:
        im: Background image for ``imshow``; its height and width set the
            ``postprocess`` distance threshold (1% of the image diagonal).
        filtered_pred: Detection dicts carrying ``'boxes'``, ``'scores'``,
            ``'line'`` and ``'line_score'`` tensors.
        t_start: ``time.time()`` timestamp; the time elapsed since it is printed.

    Returns:
        Path of the saved figure (``"temp_result.png"``).

    Pairs whose box score is below 0.7 or whose line score is below 0.9 are
    skipped. A box and its segment are drawn in the same palette colour.
    """
    colors = get_colors()
    # The postprocess threshold depends only on the image size, so compute it
    # once instead of once per prediction.
    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for idx, pred in enumerate(filtered_pred):
        boxes = pred['boxes'].cpu().numpy()
        box_scores = pred['scores'].cpu().numpy()
        lines = pred['line'].cpu().numpy()
        line_scores = pred['line_score'].cpu().numpy()
        nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
        # NOTE(review): boxes are paired with postprocess output by position.
        # If postprocess drops or merges any segment, later boxes get the
        # wrong line and zip() silently skips the trailing boxes — confirm
        # that it returns exactly one segment per input here.
        for box_idx, (box, line, box_score, line_score) in enumerate(
            zip(boxes, nlines, box_scores, nscores)
        ):
            if box_score < 0.7 or line_score < 0.9:
                continue
            # Skip drawing when the segment is empty (no valid segment found).
            if line is None or len(line) == 0:
                continue
            x0, y0, x1, y1 = box
            a, b = line
            # Colour chosen by position, cycling through the palette.
            color = colors[(idx + box_idx) % len(colors)]
            ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
            # Endpoints are (row, col): index 1 is x, index 0 is y.
            ax.scatter(a[1], a[0], c=color, s=10)
            ax.scatter(b[1], b[0], c=color, s=10)
            ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
    t_end = time.time()
    print(f'show_predict used: {t_end - t_start:.2f} seconds')
    output_path = "temp_result.png"
    save_plot(output_path)
    return output_path
if __name__ == "__main__":
    import sys

    # Optional first CLI argument: the image to run on. Without it, fall back
    # to the sample image that used to be hard-coded here.
    default_image = '/data/lm/wirenet_1000/images/train/00031643_1.png'
    predict(sys.argv[1] if len(sys.argv) > 1 else default_image)
|