main_lm_0223.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. from fastapi import FastAPI, File, UploadFile, HTTPException
  2. from fastapi.responses import FileResponse
  3. from fastapi.staticfiles import StaticFiles
  4. import os
  5. import torch
  6. import numpy as np
  7. from PIL import Image
  8. import skimage.io
  9. import skimage.color
  10. from torchvision import transforms
  11. import shutil
  12. import matplotlib.pyplot as plt
  13. from models.line_detect.line_net import linenet_resnet50_fpn
  14. from models.line_detect.postprocess import postprocess
  15. from rtree import index
  16. import time
  17. import multiprocessing as mp
  18. from fastapi.middleware.cors import CORSMiddleware
  19. # ??? FastAPI
  20. app = FastAPI()
  21. # ?? CORS ???
  22. app.add_middleware(
  23. CORSMiddleware,
  24. allow_origins=["*"], # ?????
  25. allow_credentials=True,
  26. allow_methods=["*"],
  27. allow_headers=["*"],
  28. )
  29. # ?????????? 'spawn'
  30. mp.set_start_method('spawn', force=True)
  31. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  32. def load_model(model_path):
  33. """???????????"""
  34. model = linenet_resnet50_fpn().to(device)
  35. if os.path.exists(model_path):
  36. checkpoint = torch.load(model_path, map_location=device)
  37. model.load_state_dict(checkpoint['model_state_dict'])
  38. print(f"Loaded model from {model_path}")
  39. else:
  40. raise FileNotFoundError(f"No saved model found at {model_path}")
  41. model.eval()
  42. return model
  43. def preprocess_image(image_path):
  44. """????????"""
  45. img = Image.open(image_path).convert("RGB")
  46. transform = transforms.ToTensor()
  47. img_tensor = transform(img)
  48. resized_img = skimage.transform.resize(
  49. img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
  50. )
  51. return torch.tensor(resized_img).permute(2, 0, 1)
  52. def save_plot(output_path: str):
  53. """?????????"""
  54. plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
  55. print(f"Saved plot to {output_path}")
  56. plt.close()
  57. def get_colors():
  58. """????????????"""
  59. return [
  60. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  61. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  62. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  63. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  64. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  65. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  66. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  67. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  68. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  69. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  70. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  71. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  72. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  73. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  74. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  75. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  76. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  77. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  78. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  79. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  80. ]
# Legacy version of process_box, kept for reference. It picked the longest
# enclosed segment per box (scoring by Euclidean length) rather than the
# highest-scoring one, and emitted no placeholder for unmatched boxes.
# def process_box(box, lines, scores):
#     """Process each bounding box and find its best-matching line segment."""
#     valid_lines = []  # matched line segments
#     valid_scores = []  # scores of the matched segments
#     for i in box:
#         best_line = None
#         max_length = 0.0
#         # Iterate over all line segments
#         for j in range(lines.shape[1]):
#             line_j = lines[0, j].cpu().numpy() / 128 * 512
#             # Check whether the segment lies entirely inside the box
#             if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
#                     line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
#                     line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
#                     line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
#                 # Segment length
#                 length = np.linalg.norm(line_j[0] - line_j[1])
#                 if length > max_length:
#                     best_line = line_j
#                     max_length = length
#         # Keep the best segment only if one was found
#         if best_line is not None:
#             valid_lines.append(best_line)
#             valid_scores.append(max_length)  # the length doubles as the score
#     return valid_lines, valid_scores
  106. def process_box(box, lines, scores):
  107. """´¦Àíµ¥¸ö±ß½ç¿ò£¬ÕÒµ½×î¼ÑÆ¥ÅäµÄÏß¶Î"""
  108. valid_lines = [] # ´æ´¢ÓÐЧµÄÏß¶Î
  109. valid_scores = [] # ´æ´¢ÓÐЧµÄ·ÖÊý
  110. # print(f'score:{len(scores)}')
  111. for i in box:
  112. best_line = None
  113. max_length = 0.0
  114. # ±éÀúËùÓÐÏß¶Î
  115. for j in range(lines.shape[1]):
  116. line_j = lines[0, j].cpu().numpy() / 128 * 512
  117. # ¼ì²éÏß¶ÎÊÇ·ñÍêÈ«ÔÚboxÄÚ
  118. if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
  119. line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
  120. line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
  121. line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
  122. # length = np.linalg.norm(line_j[0] - line_j[1])
  123. length = scores[j].item()
  124. # print(length)
  125. if length > max_length:
  126. best_line = line_j
  127. max_length = length
  128. if best_line is not None:
  129. valid_lines.append(best_line)
  130. valid_scores.append(max_length)
  131. else:
  132. valid_lines.append([[0.0,0.0],[0.0,0.0]])
  133. valid_scores.append(0.0)
  134. # print(f'valid_lines:{valid_lines}')
  135. # print(f'valid_scores:{valid_scores}')
  136. return valid_lines, valid_scores
  137. def box_line_optimized_parallel(pred):
  138. """?????????????"""
  139. lines = pred[-1]['wires']['lines'] # ???[1, 2500, 2, 2]
  140. scores = pred[-1]['wires']['score'][0] # ?????[2500]
  141. boxes = [box_['boxes'].cpu().numpy() for box_ in pred[0:-1]] # ??box
  142. num_processes = min(mp.cpu_count(), len(boxes)) # ????????
  143. with mp.Pool(processes=num_processes) as pool:
  144. results = pool.starmap(
  145. process_box,
  146. [(box, lines, scores) for box in boxes]
  147. )
  148. # ??????
  149. filtered_pred = []
  150. for idx_box, (valid_lines, valid_scores) in enumerate(results):
  151. if valid_lines:
  152. pred[idx_box]['line'] = torch.tensor(valid_lines)
  153. pred[idx_box]['line_score'] = torch.tensor(valid_scores)
  154. filtered_pred.append(pred[idx_box])
  155. return filtered_pred
  156. @app.get("/")
  157. def read_root():
  158. """??????"""
  159. return FileResponse("static/index.html")
  160. @app.post("/predict")
  161. @app.post("/predict/")
  162. async def predict(file: UploadFile = File(...)):
  163. try:
  164. start_time = time.time()
  165. # ??????
  166. os.makedirs("uploaded_images", exist_ok=True)
  167. image_path = f"uploaded_images/{file.filename}"
  168. with open(image_path, "wb") as f:
  169. shutil.copyfileobj(file.file, f)
  170. # ????
  171. model_path = "/data/share/rlq/weights/linenet_wts/resnet50_best_e280.pth"
  172. model = load_model(model_path)
  173. # ?????
  174. img_tensor = preprocess_image(image_path)
  175. im = img_tensor.permute(1, 2, 0).cpu().numpy()
  176. # ????
  177. with torch.no_grad():
  178. predictions = model([img_tensor.to(device)])
  179. # ????????
  180. t_start = time.time()
  181. filtered_pred = box_line_optimized_parallel(predictions)
  182. t_end = time.time()
  183. print(f'Matched boxes and lines used: {t_end - t_start:.2f} seconds')
  184. # ????
  185. output_path_box = show_box(im, predictions, t_start)
  186. output_path_line = show_line(im, predictions, t_start)
  187. output_path_boxandline = show_predict(im, filtered_pred, t_start)
  188. # # ????
  189. # combined_image_path = "combined_result.png"
  190. # combine_images(
  191. # [output_path_boxandline],
  192. # # [output_path_boxandline, output_path_box, output_path_line],
  193. # titles=["Box and Line"],
  194. # output_path=combined_image_path
  195. # )
  196. end_time = time.time()
  197. print(f'Total time: {end_time - start_time:.2f} seconds')
  198. # ???????
  199. return FileResponse(output_path_boxandline, media_type="image/png", filename="result.png")
  200. except Exception as e:
  201. raise HTTPException(status_code=500, detail=str(e))
  202. def combine_images(image_paths: list, titles: list, output_path: str):
  203. """????????????"""
  204. fig, axes = plt.subplots(1, len(image_paths), figsize=(15, 5))
  205. for ax, img_path, title in zip(axes, image_paths, titles):
  206. ax.imshow(plt.imread(img_path))
  207. ax.set_title(title)
  208. ax.axis("off")
  209. plt.savefig(output_path, bbox_inches="tight", pad_inches=0)
  210. plt.close()
  211. def show_box(im, predictions, t_start):
  212. """??????????"""
  213. boxes = predictions[0]['boxes'].cpu().numpy()
  214. box_scores = predictions[0]['scores'].cpu().numpy()
  215. colors = get_colors()
  216. fig, ax = plt.subplots(figsize=(10, 10))
  217. ax.imshow(im)
  218. for idx, (box, score) in enumerate(zip(boxes, box_scores)):
  219. if score < 0.7:
  220. continue
  221. x0, y0, x1, y1 = box
  222. ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=colors[idx % len(colors)], linewidth=1))
  223. t_end = time.time()
  224. print(f'show_box used: {t_end - t_start:.2f} seconds')
  225. output_path = "temp_result_box.png"
  226. save_plot(output_path)
  227. return output_path
  228. def show_line(im, predictions, t_start):
  229. """?????????"""
  230. lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  231. line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
  232. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  233. nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
  234. fig, ax = plt.subplots(figsize=(10, 10))
  235. ax.imshow(im)
  236. for (a, b), s in zip(nlines, nscores):
  237. if s < 0.9:
  238. continue
  239. ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
  240. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  241. t_end = time.time()
  242. print(f'show_line used: {t_end - t_start:.2f} seconds')
  243. output_path = "temp_result_line.png"
  244. save_plot(output_path)
  245. return output_path
  246. def show_predict(im, pred, t_start):
  247. """?????????????????"""
  248. col = get_colors()
  249. fig, ax = plt.subplots(figsize=(10, 10))
  250. ax.imshow(im)
  251. boxes = pred[0]['boxes'].cpu().numpy()
  252. box_scores = pred[0]['scores'].cpu().numpy()
  253. lines = pred[0]['line'].cpu().numpy()
  254. line_scores = pred[0]['line_score'].cpu().numpy()
  255. idx = 0
  256. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  257. for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
  258. x0, y0, x1, y1 = box
  259. if np.array_equal(line, tmp):
  260. continue
  261. a, b = line
  262. if box_score >= 0.7 or line_score >= 0.9:
  263. ax.add_patch(
  264. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  265. ax.scatter(a[1], a[0], c='#871F78', s=10)
  266. ax.scatter(b[1], b[0], c='#871F78', s=10)
  267. ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  268. idx = idx + 1
  269. # for idx, pred in enumerate(filtered_pred):
  270. # boxes = pred['boxes'].cpu().numpy()
  271. # box_scores = pred['scores'].cpu().numpy()
  272. # lines = pred['line'].cpu().numpy()
  273. # line_scores = pred['line_score'].cpu().numpy()
  274. # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  275. # nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
  276. # for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
  277. # if box_score > 0.7 and line_score > 0.9:
  278. # x0, y0, x1, y1 = box
  279. # a, b = line
  280. # color = colors[(idx + box_idx) % len(colors)] # ?????????????
  281. # ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
  282. # ax.scatter(a[1], a[0], c=color, s=10)
  283. # ax.scatter(b[1], b[0], c=color, s=10)
  284. # ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
  285. t_end = time.time()
  286. print(f'show_predict used: {t_end - t_start:.2f} seconds')
  287. output_path = "temp_result.png"
  288. save_plot(output_path)
  289. return output_path
  290. if __name__ == "__main__":
  291. import uvicorn
  292. uvicorn.run(app, host="0.0.0.0", port=8001)