|
@@ -0,0 +1,320 @@
|
|
|
+from fastapi import FastAPI, File, UploadFile, HTTPException
|
|
|
+from fastapi.responses import FileResponse
|
|
|
+from fastapi.staticfiles import StaticFiles
|
|
|
+import os
|
|
|
+import torch
|
|
|
+import numpy as np
|
|
|
+from PIL import Image
|
|
|
+import skimage.io
|
|
|
+import skimage.color
|
|
|
+from torchvision import transforms
|
|
|
+import shutil
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+from models.line_detect.line_net import linenet_resnet50_fpn
|
|
|
+from models.line_detect.postprocess import postprocess
|
|
|
+from rtree import index
|
|
|
+import time
|
|
|
+import multiprocessing as mp
|
|
|
+from fastapi.middleware.cors import CORSMiddleware
|
|
|
+# ??? FastAPI
|
|
|
+app = FastAPI()
|
|
|
+
|
|
|
+# ?? CORS ???
|
|
|
+app.add_middleware(
|
|
|
+ CORSMiddleware,
|
|
|
+ allow_origins=["*"], # ?????
|
|
|
+ allow_credentials=True,
|
|
|
+ allow_methods=["*"],
|
|
|
+ allow_headers=["*"],
|
|
|
+)
|
|
|
+
|
|
|
+# ?????????? 'spawn'
|
|
|
+mp.set_start_method('spawn', force=True)
|
|
|
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
+
|
|
|
+def load_model(model_path):
|
|
|
+ """???????????"""
|
|
|
+ model = linenet_resnet50_fpn().to(device)
|
|
|
+ if os.path.exists(model_path):
|
|
|
+ checkpoint = torch.load(model_path, map_location=device)
|
|
|
+ model.load_state_dict(checkpoint['model_state_dict'])
|
|
|
+ print(f"Loaded model from {model_path}")
|
|
|
+ else:
|
|
|
+ raise FileNotFoundError(f"No saved model found at {model_path}")
|
|
|
+ model.eval()
|
|
|
+ return model
|
|
|
+
|
|
|
+def preprocess_image(image_path):
|
|
|
+ """????????"""
|
|
|
+ img = Image.open(image_path).convert("RGB")
|
|
|
+ transform = transforms.ToTensor()
|
|
|
+ img_tensor = transform(img)
|
|
|
+ resized_img = skimage.transform.resize(
|
|
|
+ img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
|
|
|
+ )
|
|
|
+ return torch.tensor(resized_img).permute(2, 0, 1)
|
|
|
+
|
|
|
+def save_plot(output_path: str):
|
|
|
+ """?????????"""
|
|
|
+ plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
|
|
|
+ print(f"Saved plot to {output_path}")
|
|
|
+ plt.close()
|
|
|
+
|
|
|
+def get_colors():
|
|
|
+ """????????????"""
|
|
|
+ return [
|
|
|
+ '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
|
|
|
+ '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
|
|
|
+ '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
|
|
|
+ '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
|
|
|
+ '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
|
|
|
+ '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
|
|
|
+ '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
|
|
|
+ '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
|
|
|
+ '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
|
|
|
+ '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
|
|
|
+ '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
|
|
|
+ '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
|
|
|
+ '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
|
|
|
+ '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
|
|
|
+ '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
|
|
|
+ '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
|
|
|
+ '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
|
|
|
+ '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
|
|
|
+ '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
|
|
|
+ '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
|
|
|
+ ]
|
|
|
+
|
|
|
+# def process_box(box, lines, scores):
|
|
|
+# """?????????????????"""
|
|
|
+# valid_lines = [] # ???????
|
|
|
+# valid_scores = [] # ???????
|
|
|
+# for i in box:
|
|
|
+# best_line = None
|
|
|
+# max_length = 0.0
|
|
|
+# # ??????
|
|
|
+# for j in range(lines.shape[1]):
|
|
|
+# line_j = lines[0, j].cpu().numpy() / 128 * 512
|
|
|
+# # ?????????box?
|
|
|
+# if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
|
|
|
+# line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
|
|
|
+# line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
|
|
|
+# line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
|
|
|
+# # ??????
|
|
|
+# length = np.linalg.norm(line_j[0] - line_j[1])
|
|
|
+# if length > max_length:
|
|
|
+# best_line = line_j
|
|
|
+# max_length = length
|
|
|
+# # ?????????????????
|
|
|
+# if best_line is not None:
|
|
|
+# valid_lines.append(best_line)
|
|
|
+# valid_scores.append(max_length) # ??????????
|
|
|
+# return valid_lines, valid_scores
|
|
|
+def process_box(box, lines, scores):
|
|
|
+ """´¦Àíµ¥¸ö±ß½ç¿ò£¬ÕÒµ½×î¼ÑÆ¥ÅäµÄÏß¶Î"""
|
|
|
+ valid_lines = [] # ´æ´¢ÓÐЧµÄÏß¶Î
|
|
|
+ valid_scores = [] # ´æ´¢ÓÐЧµÄ·ÖÊý
|
|
|
+ # print(f'score:{len(scores)}')
|
|
|
+ for i in box:
|
|
|
+ best_line = None
|
|
|
+ max_length = 0.0
|
|
|
+ # ±éÀúËùÓÐÏß¶Î
|
|
|
+ for j in range(lines.shape[1]):
|
|
|
+ line_j = lines[0, j].cpu().numpy() / 128 * 512
|
|
|
+ # ¼ì²éÏß¶ÎÊÇ·ñÍêÈ«ÔÚboxÄÚ
|
|
|
+ if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
|
|
|
+ line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
|
|
|
+ line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
|
|
|
+ line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
|
|
|
+ # length = np.linalg.norm(line_j[0] - line_j[1])
|
|
|
+ length = scores[j].item()
|
|
|
+ # print(length)
|
|
|
+ if length > max_length:
|
|
|
+ best_line = line_j
|
|
|
+ max_length = length
|
|
|
+ if best_line is not None:
|
|
|
+ valid_lines.append(best_line)
|
|
|
+ valid_scores.append(max_length)
|
|
|
+
|
|
|
+ else:
|
|
|
+ valid_lines.append([[0.0,0.0],[0.0,0.0]])
|
|
|
+ valid_scores.append(0.0)
|
|
|
+ # print(f'valid_lines:{valid_lines}')
|
|
|
+ # print(f'valid_scores:{valid_scores}')
|
|
|
+ return valid_lines, valid_scores
|
|
|
+
|
|
|
+def box_line_optimized_parallel(pred):
|
|
|
+ """?????????????"""
|
|
|
+ lines = pred[-1]['wires']['lines'] # ???[1, 2500, 2, 2]
|
|
|
+ scores = pred[-1]['wires']['score'][0] # ?????[2500]
|
|
|
+ boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]] # ??box
|
|
|
+ num_processes = min(mp.cpu_count(), len(boxes)) # ????????
|
|
|
+ with mp.Pool(processes=num_processes) as pool:
|
|
|
+ results = pool.starmap(
|
|
|
+ process_box,
|
|
|
+ [(box, lines, scores) for box in boxes]
|
|
|
+ )
|
|
|
+ # ??????
|
|
|
+ filtered_pred = []
|
|
|
+ for idx_box, (valid_lines, valid_scores) in enumerate(results):
|
|
|
+ if valid_lines:
|
|
|
+ pred[idx_box]['line'] = torch.tensor(valid_lines)
|
|
|
+ pred[idx_box]['line_score'] = torch.tensor(valid_scores)
|
|
|
+ filtered_pred.append(pred[idx_box])
|
|
|
+ return filtered_pred
|
|
|
+
|
|
|
+@app.get("/")
|
|
|
+def read_root():
|
|
|
+ """??????"""
|
|
|
+ return FileResponse("static/index.html")
|
|
|
+
|
|
|
+@app.post("/predict")
|
|
|
+@app.post("/predict/")
|
|
|
+async def predict(file: UploadFile = File(...)):
|
|
|
+ try:
|
|
|
+ start_time = time.time()
|
|
|
+
|
|
|
+ # ??????
|
|
|
+ os.makedirs("uploaded_images", exist_ok=True)
|
|
|
+ image_path = f"uploaded_images/{file.filename}"
|
|
|
+ with open(image_path, "wb") as f:
|
|
|
+ shutil.copyfileobj(file.file, f)
|
|
|
+
|
|
|
+ # ????
|
|
|
+ model_path = "/data/share/rlq/weights/linenet_wts/resnet50_best_e280.pth"
|
|
|
+ model = load_model(model_path)
|
|
|
+
|
|
|
+ # ?????
|
|
|
+ img_tensor = preprocess_image(image_path)
|
|
|
+ im = img_tensor.permute(1, 2, 0).cpu().numpy()
|
|
|
+
|
|
|
+ # ????
|
|
|
+ with torch.no_grad():
|
|
|
+ predictions = model([img_tensor.to(device)])
|
|
|
+
|
|
|
+ # ????????
|
|
|
+ t_start = time.time()
|
|
|
+ filtered_pred = box_line_optimized_parallel(predictions)
|
|
|
+ t_end = time.time()
|
|
|
+ print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')
|
|
|
+
|
|
|
+ # ????
|
|
|
+ output_path_box = show_box(im, predictions, t_start)
|
|
|
+ output_path_line = show_line(im, predictions, t_start)
|
|
|
+ output_path_boxandline = show_predict(im, filtered_pred, t_start)
|
|
|
+
|
|
|
+ # # ????
|
|
|
+ # combined_image_path = "combined_result.png"
|
|
|
+ # combine_images(
|
|
|
+ # [output_path_boxandline],
|
|
|
+ # # [output_path_boxandline, output_path_box, output_path_line],
|
|
|
+ # titles=["Box and Line"],
|
|
|
+ # output_path=combined_image_path
|
|
|
+ # )
|
|
|
+
|
|
|
+ end_time = time.time()
|
|
|
+ print(f'Total time: {end_time - start_time:.2f} seconds')
|
|
|
+
|
|
|
+ # ???????
|
|
|
+ return FileResponse(output_path_boxandline, media_type="image/png", filename="result.png")
|
|
|
+ except Exception as e:
|
|
|
+ raise HTTPException(status_code=500, detail=str(e))
|
|
|
+
|
|
|
+def combine_images(image_paths: list, titles: list, output_path: str):
|
|
|
+ """????????????"""
|
|
|
+ fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5))
|
|
|
+ for ax, img_path, title in zip(axes, image_paths, titles):
|
|
|
+ ax.imshow(plt.imread(img_path))
|
|
|
+ ax.set_title(title)
|
|
|
+ ax.axis("off")
|
|
|
+ plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
|
|
|
+ plt.close()
|
|
|
+
|
|
|
+def show_box(im, predictions, t_start):
|
|
|
+ """??????????"""
|
|
|
+ boxes = predictions[0]['boxes'].cpu().numpy()
|
|
|
+ box_scores = predictions[0]['scores'].cpu().numpy()
|
|
|
+ colors = get_colors()
|
|
|
+ fig, ax = plt.subplots(figsize=(10, 10))
|
|
|
+ ax.imshow(im)
|
|
|
+ for idx, (box, score) in enumerate(zip(boxes, box_scores)):
|
|
|
+ if score < 0.7:
|
|
|
+ continue
|
|
|
+ x0, y0, x1, y1 = box
|
|
|
+ ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=colors[idx % len(colors)], linewidth=1))
|
|
|
+ t_end = time.time()
|
|
|
+ print(f'show_box used: {t_end - t_start:.2f} seconds')
|
|
|
+ output_path = "temp_result_box.png"
|
|
|
+ save_plot(output_path)
|
|
|
+ return output_path
|
|
|
+
|
|
|
+def show_line(im, predictions, t_start):
|
|
|
+ """?????????"""
|
|
|
+ lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
|
|
|
+ line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
|
|
|
+ diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
|
|
|
+ nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
|
|
|
+ fig, ax = plt.subplots(figsize=(10, 10))
|
|
|
+ ax.imshow(im)
|
|
|
+ for (a, b), s in zip(nlines, nscores):
|
|
|
+ if s < 0.9:
|
|
|
+ continue
|
|
|
+ ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
|
|
|
+ ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
|
|
|
+ t_end = time.time()
|
|
|
+ print(f'show_line used: {t_end - t_start:.2f} seconds')
|
|
|
+ output_path = "temp_result_line.png"
|
|
|
+ save_plot(output_path)
|
|
|
+ return output_path
|
|
|
+
|
|
|
+def show_predict(im, pred, t_start):
|
|
|
+ """?????????????????"""
|
|
|
+ col = get_colors()
|
|
|
+ fig, ax = plt.subplots(figsize=(10, 10))
|
|
|
+ ax.imshow(im)
|
|
|
+
|
|
|
+ boxes = pred[0]['boxes'].cpu().numpy()
|
|
|
+ box_scores = pred[0]['scores'].cpu().numpy()
|
|
|
+ lines = pred[0]['line'].cpu().numpy()
|
|
|
+ line_scores = pred[0]['line_score'].cpu().numpy()
|
|
|
+ idx = 0
|
|
|
+
|
|
|
+ tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
|
|
|
+
|
|
|
+ for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
|
|
|
+ x0, y0, x1, y1 = box
|
|
|
+ if np.array_equal(line, tmp):
|
|
|
+ continue
|
|
|
+ a, b = line
|
|
|
+ if box_score >= 0.7 or line_score >= 0.9:
|
|
|
+ ax.add_patch(
|
|
|
+ plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
|
|
|
+ ax.scatter(a[1], a[0], c='#871F78', s=10)
|
|
|
+ ax.scatter(b[1], b[0], c='#871F78', s=10)
|
|
|
+ ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
|
|
|
+ idx = idx + 1
|
|
|
+ # for idx, pred in enumerate(filtered_pred):
|
|
|
+ # boxes = pred['boxes'].cpu().numpy()
|
|
|
+ # box_scores = pred['scores'].cpu().numpy()
|
|
|
+ # lines = pred['line'].cpu().numpy()
|
|
|
+ # line_scores = pred['line_score'].cpu().numpy()
|
|
|
+ # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
|
|
|
+ # nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
|
|
|
+ # for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
|
|
|
+ # if box_score > 0.7 and line_score > 0.9:
|
|
|
+ # x0, y0, x1, y1 = box
|
|
|
+ # a, b = line
|
|
|
+ # color = colors[(idx + box_idx) % len(colors)] # ?????????????
|
|
|
+ # ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
|
|
|
+ # ax.scatter(a[1], a[0], c=color, s=10)
|
|
|
+ # ax.scatter(b[1], b[0], c=color, s=10)
|
|
|
+ # ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
|
|
|
+ t_end = time.time()
|
|
|
+ print(f'show_predict used: {t_end - t_start:.2f} seconds')
|
|
|
+ output_path = "temp_result.png"
|
|
|
+ save_plot(output_path)
|
|
|
+ return output_path
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ import uvicorn
|
|
|
+ uvicorn.run(app, host="0.0.0.0", port=8001)
|