import multiprocessing as mp
import os
import shutil
import time

import matplotlib.pyplot as plt
import numpy as np
import skimage.color
import skimage.io
import skimage.transform  # resize() lives here; it was used without being imported
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
from rtree import index
from torchvision import transforms

from models.line_detect.line_net import linenet_resnet50_fpn
from models.line_detect.postprocess import postprocess

# FastAPI application object; the route decorators below register on it.
app = FastAPI()

# CORS: every origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets
# any web page issue credentialed requests to this service -- confirm this
# is intended before exposing the server beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Worker processes are started with 'spawn' (presumably because CUDA state
# cannot be inherited by forked children -- TODO confirm).
mp.set_start_method('spawn', force=True)

# Inference runs on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_model(model_path):
    """Build the line-detection network and load its weights from a checkpoint.

    Args:
        model_path: Path of a checkpoint file containing a
            ``'model_state_dict'`` entry.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No saved model found at {model_path}")

    model = linenet_resnet50_fpn().to(device)
    # NOTE(review): torch.load unpickles the file, so only checkpoints from a
    # trusted location must ever be passed here.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from {model_path}")
    model.eval()
    return model


def preprocess_image(image_path):
    """Load an image file and return it as a 3x512x512 float32 tensor.

    The image is converted to RGB, turned into a tensor with
    ``transforms.ToTensor`` and resized to 512x512 with
    ``skimage.transform.resize``.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A channel-first ``torch.Tensor`` of shape (3, 512, 512), dtype float32.
    """
    # The context manager releases the file handle; convert() returns an
    # independent, fully loaded image, so it stays valid after the close.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")

    img_tensor = transforms.ToTensor()(img)
    # resize() expects channel-last data, hence the permute round trip.
    channel_last = img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32)
    resized_img = skimage.transform.resize(channel_last, (512, 512))
    # Pin float32 explicitly so the result does not depend on the dtype
    # resize() happens to return.
    return torch.tensor(resized_img, dtype=torch.float32).permute(2, 0, 1)


def save_plot(output_path: str):
    """Save the current matplotlib figure to ``output_path`` and close it.

    Args:
        output_path: Destination file for the rendered figure.
    """
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
    print(f"Saved plot to {output_path}")
    plt.close()


def get_colors():
    """Return the palette of hex colour strings used when drawing results.

    A new list is built on every call, so callers may modify the result.
    """
    return [
        '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a',
        '#d62728', '#ff9896', '#9467bd', '#c5b0d5', '#8c564b', '#c49c94',
        '#e377c2', '#f7b6d2', '#7f7f7f', '#c7c7c7', '#bcbd22', '#dbdb8d',
        '#17becf', '#9edae5', '#8dd3c7', '#ffffb3', '#bebada', '#fb8072',
        '#80b1d3', '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
        '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3', '#fdbf6f',
        '#ff7f00', '#cab2d6', '#637939', '#b5cf6b', '#cedb9c', '#8c6d31',
        '#e7969c', '#d6616b', '#7b4173', '#ad494a', '#843c39', '#dd8452',
        '#f7f7f7', '#cccccc', '#969696', '#525252', '#f7fcfd', '#e5f5f9',
        '#ccece6', '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
        '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4', '#4eb3d3',
        '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4', '#bfd3e6', '#9ebcda',
        '#8c96c6', '#8c6bb4', '#88419d', '#810f7c', '#4d004b', '#f7f7f7',
        '#efefef', '#d9d9d9', '#bfbfbf', '#969696', '#737373', '#525252',
        '#252525', '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
        '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026', '#ff6f61',
        '#ff9e64', '#ff6347', '#ffa07a', '#fa8072',
    ]

# Checkpoint served by the /predict endpoint.
MODEL_PATH = "/data/share/rlq/weights/linenet_wts/resnet50_best_e280.pth"
# Directory that receives uploaded images before they are preprocessed.
UPLOAD_DIR = "uploaded_images"
# Loaded models keyed by checkpoint path, so the weights are read from disk
# once per process instead of once per request.
_MODEL_CACHE = {}


def _as_numpy(value):
    """Return ``value`` as a NumPy array, moving torch tensors to the CPU."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def _get_model(model_path):
    """Return the model for ``model_path``, loading it on first use only."""
    model = _MODEL_CACHE.get(model_path)
    if model is None:
        model = load_model(model_path)
        _MODEL_CACHE[model_path] = model
    return model


def process_box(box, lines, scores):
    """Pick, for every bounding box, the best-scoring line lying inside it.

    A line qualifies for a box when both of its endpoints are inside the box
    (borders included). Among the qualifying lines the one with the highest
    score is chosen; on ties the line with the lowest index wins. Lines whose
    score is not strictly positive are never chosen.

    Args:
        box: Iterable of boxes, each indexable as (x0, y0, x1, y1).
        lines: Tensor or array indexed as ``lines[0, j]``, where entry ``j``
            holds two endpoints stored as (y, x). Coordinates are multiplied
            by 512/128 before being compared with the boxes.
        scores: Tensor or array with one score per line.

    Returns:
        A ``(valid_lines, valid_scores)`` pair of lists with one entry per
        box. A box without a qualifying line gets an all-zero 2x2 line and a
        score of 0.0 as placeholder.
    """
    # Convert once up front instead of once per (box, line) pair.
    scaled_lines = _as_numpy(lines)[0] / 128 * 512
    line_scores = _as_numpy(scores).reshape(-1)
    ys = scaled_lines[:, :, 0]
    xs = scaled_lines[:, :, 1]
    positive = line_scores > 0.0

    valid_lines = []
    valid_scores = []
    for row in box:
        x0, y0, x1, y1 = row[0], row[1], row[2], row[3]
        # True for the lines whose two endpoints both fall inside this box.
        inside = (
            (xs >= x0) & (xs <= x1) & (ys >= y0) & (ys <= y1)
        ).all(axis=1)
        eligible = inside & positive
        if eligible.any():
            # argmax returns the first maximum, which reproduces the
            # "first line wins on ties" rule of a strict '>' scan.
            best = int(np.argmax(np.where(eligible, line_scores, -np.inf)))
            valid_lines.append(scaled_lines[best].astype(np.float32))
            valid_scores.append(float(line_scores[best]))
        else:
            valid_lines.append(np.zeros((2, 2), dtype=np.float32))
            valid_scores.append(0.0)
    return valid_lines, valid_scores


def box_line_optimized_parallel(pred):
    """Attach the best matching line to every box of every box prediction.

    ``pred[-1]['wires']`` supplies the candidate lines and their scores; all
    preceding entries of ``pred`` supply the boxes. Each box entry that has
    at least one box is updated in place with a ``'line'`` and a
    ``'line_score'`` tensor and is included in the result.

    Args:
        pred: Model output; a sequence of dicts with a ``'boxes'`` entry,
            followed by one dict holding ``['wires']['lines']`` and
            ``['wires']['score']``.

    Returns:
        The list of updated box entries; entries without boxes are left out.
    """
    # Plain NumPy arrays are cheap to hand to worker processes and avoid
    # sharing (possibly CUDA) tensors across process boundaries.
    lines = _as_numpy(pred[-1]['wires']['lines'])
    scores = _as_numpy(pred[-1]['wires']['score'][0])
    tasks = [(_as_numpy(entry['boxes']), lines, scores) for entry in pred[0:-1]]

    if not tasks:
        # mp.Pool(processes=0) would raise ValueError.
        return []
    if len(tasks) == 1:
        # Starting a worker process costs far more than matching one box set.
        results = [process_box(*tasks[0])]
    else:
        num_processes = min(mp.cpu_count(), len(tasks))
        with mp.Pool(processes=num_processes) as pool:
            results = pool.starmap(process_box, tasks)

    filtered_pred = []
    for idx_box, (valid_lines, valid_scores) in enumerate(results):
        if valid_lines:
            pred[idx_box]['line'] = torch.as_tensor(
                np.stack(valid_lines), dtype=torch.float32)
            pred[idx_box]['line_score'] = torch.tensor(
                valid_scores, dtype=torch.float32)
            filtered_pred.append(pred[idx_box])
    return filtered_pred


@app.get("/")
def read_root():
    """Serve the static front-end page."""
    return FileResponse("static/index.html")


@app.post("/predict")
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Run detection on an uploaded image and return the rendered result.

    The upload is stored under ``UPLOAD_DIR``, preprocessed, passed through
    the model, and the matched boxes and lines are drawn into a PNG.

    Args:
        file: The uploaded image.

    Returns:
        A ``FileResponse`` with the box-and-line visualisation as PNG.

    Raises:
        HTTPException: 400 when the upload has no usable file name, 500 when
            any processing step fails.
    """
    # The file name is client controlled: keep only its last path component
    # so that names such as "../../x" cannot escape the upload directory.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid upload filename")

    # NOTE: the handler is 'async' but never awaits, so requests are handled
    # one at a time; the fixed output file names below rely on that.
    try:
        start_time = time.time()

        # Store the upload on disk.
        os.makedirs(UPLOAD_DIR, exist_ok=True)
        image_path = os.path.join(UPLOAD_DIR, safe_name)
        with open(image_path, "wb") as f:
            shutil.copyfileobj(file.file, f)

        # The model is loaded on the first request and reused afterwards.
        model = _get_model(MODEL_PATH)

        img_tensor = preprocess_image(image_path)
        im = img_tensor.permute(1, 2, 0).cpu().numpy()

        with torch.no_grad():
            predictions = model([img_tensor.to(device)])

        # Match every box with its best line.
        t_start = time.time()
        filtered_pred = box_line_optimized_parallel(predictions)
        t_end = time.time()
        print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')

        # The box-only and line-only renders are written to disk as
        # diagnostics; only the combined render is returned to the client.
        show_box(im, predictions, t_start)
        show_line(im, predictions, t_start)
        output_path_boxandline = show_predict(im, filtered_pred, t_start)

        end_time = time.time()
        print(f'Total time: {end_time - start_time:.2f} seconds')

        return FileResponse(output_path_boxandline, media_type="image/png",
                            filename="result.png")
    except HTTPException:
        # Keep deliberate HTTP errors instead of re-wrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


def combine_images(image_paths: list, titles: list, output_path: str):
    """Place several images side by side in one figure and save it.

    Args:
        image_paths: Paths of the images to show, left to right.
        titles: One title per image.
        output_path: Destination file of the combined figure.

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    if not image_paths:
        raise ValueError("image_paths must contain at least one image")

    # squeeze=False always yields a 2-D array of axes; with the default, a
    # single image would give a bare Axes object that cannot be iterated.
    fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5),
                             squeeze=False)
    for ax, img_path, title in zip(axes.ravel(), image_paths, titles):
        ax.imshow(plt.imread(img_path))
        ax.set_title(title)
        ax.axis("off")
    fig.savefig(output_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)


def show_box(im, predictions, t_start):
    """Draw the boxes of ``predictions[0]`` scoring at least 0.7 over ``im``.

    Args:
        im: Image array accepted by ``Axes.imshow``.
        predictions: Model output; ``predictions[0]`` provides ``'boxes'``
            and ``'scores'`` tensors.
        t_start: Reference timestamp used for the elapsed-time log line.

    Returns:
        The path of the saved figure, ``"temp_result_box.png"``.
    """
    boxes = predictions[0]['boxes'].cpu().numpy()
    box_scores = predictions[0]['scores'].cpu().numpy()
    colors = get_colors()

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for idx, (box, score) in enumerate(zip(boxes, box_scores)):
        if score < 0.7:
            continue
        x0, y0, x1, y1 = box
        ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                                   edgecolor=colors[idx % len(colors)],
                                   linewidth=1))

    t_end = time.time()
    print(f'show_box used: {t_end - t_start:.2f} seconds')

    output_path = "temp_result_box.png"
    save_plot(output_path)
    return output_path


def show_line(im, predictions, t_start):
    """Draw the post-processed lines scoring at least 0.9 over ``im``.

    Args:
        im: Image array accepted by ``Axes.imshow``.
        predictions: Model output; ``predictions[-1]['wires']`` provides the
            ``'lines'`` and ``'score'`` tensors.
        t_start: Reference timestamp used for the elapsed-time log line.

    Returns:
        The path of the saved figure, ``"temp_result_line.png"``.
    """
    # Line coordinates are rescaled by 512/128 before post-processing.
    lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
    line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]

    diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
    nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for (a, b), s in zip(nlines, nscores):
        if s < 0.9:
            continue
        # Endpoints are stored as (y, x); matplotlib wants x first.
        ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
        ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)

    t_end = time.time()
    print(f'show_line used: {t_end - t_start:.2f} seconds')

    output_path = "temp_result_line.png"
    save_plot(output_path)
    return output_path


def show_predict(im, pred, t_start):
    """Draw every box of ``pred[0]`` together with its matched line.

    A box is drawn when it has a matched line (not the all-zero placeholder)
    and either its box score is at least 0.7 or its line score is at least
    0.9. When ``pred`` is empty only the image itself is rendered.

    Args:
        im: Image array accepted by ``Axes.imshow``.
        pred: Box entries carrying ``'boxes'``, ``'scores'``, ``'line'`` and
            ``'line_score'`` tensors; only ``pred[0]`` is drawn.
        t_start: Reference timestamp used for the elapsed-time log line.

    Returns:
        The path of the saved figure, ``"temp_result.png"``.
    """
    col = get_colors()
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)

    if pred:
        boxes = pred[0]['boxes'].cpu().numpy()
        box_scores = pred[0]['scores'].cpu().numpy()
        lines = pred[0]['line'].cpu().numpy()
        line_scores = pred[0]['line_score'].cpu().numpy()

        # All-zero line marks a box for which no line was matched.
        placeholder = np.array([[0.0, 0.0], [0.0, 0.0]])
        idx = 0
        for box, line, box_score, line_score in zip(
                boxes, lines, box_scores, line_scores):
            if np.array_equal(line, placeholder):
                continue
            x0, y0, x1, y1 = box
            a, b = line
            if box_score >= 0.7 or line_score >= 0.9:
                # Wrap around the palette so that many detections cannot
                # index past its end.
                color = col[idx % len(col)]
                ax.add_patch(
                    plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False,
                                  edgecolor=color, linewidth=1))
                ax.scatter(a[1], a[0], c='#871F78', s=10)
                ax.scatter(b[1], b[0], c='#871F78', s=10)
                ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
                idx = idx + 1

    t_end = time.time()
    print(f'show_predict used: {t_end - t_start:.2f} seconds')

    output_path = "temp_result.png"
    save_plot(output_path)
    return output_path


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="192.168.50.100", port=808, log_level="debug")