aaa.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. # from fastapi import FastAPI, File, UploadFile, HTTPException
  2. # from fastapi.responses import FileResponse
  3. # from fastapi.staticfiles import StaticFiles
  4. import os
  5. import torch
  6. import numpy as np
  7. from PIL import Image
  8. import skimage.io
  9. import skimage.color
  10. from torchvision import transforms
  11. import shutil
  12. import matplotlib.pyplot as plt
  13. from models.line_net.line_net import linenet_resnet50_fpn
  14. from models.wirenet.postprocess import postprocess
  15. from rtree import index
  16. import time
  17. import multiprocessing as mp
  18. # from code.ubuntu2204.VisionWeldRobotMK.multivisionmodels.models.line_net.boxline import show_box
  19. # ÉèÖÃ¶à½ø³ÌÆô¶¯·½Ê½Îª 'spawn'
  20. mp.set_start_method('spawn', force=True)
  21. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  22. def load_model(model_path):
  23. """¼ÓÔØÄ£ÐͲ¢·µ»ØÄ£ÐÍʵÀý"""
  24. model = linenet_resnet50_fpn().to(device)
  25. if os.path.exists(model_path):
  26. checkpoint = torch.load(model_path, map_location=device)
  27. model.load_state_dict(checkpoint['model_state_dict'])
  28. print(f"Loaded model from {model_path}")
  29. else:
  30. raise FileNotFoundError(f"No saved model found at {model_path}")
  31. model.eval()
  32. return model
  33. def preprocess_image(image_path):
  34. """Ô¤´¦ÀíÉÏ´«µÄͼƬ"""
  35. img = Image.open(image_path).convert("RGB")
  36. transform = transforms.ToTensor()
  37. img_tensor = transform(img)
  38. resized_img = skimage.transform.resize(
  39. img_tensor.permute(1, 2, 0).cpu().numpy().astype(np.float32), (512, 512)
  40. )
  41. return torch.tensor(resized_img).permute(2, 0, 1),img
  42. def save_plot(output_path: str):
  43. """±£´æÍ¼Ïñ²¢¹Ø±Õ»æÍ¼"""
  44. plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
  45. print(f"Saved plot to {output_path}")
  46. plt.close()
  47. def get_colors():
  48. """·µ»ØÒ»×éÔ¤¶¨ÒåµÄÑÕÉ«Áбí"""
  49. return [
  50. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  51. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  52. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  53. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  54. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  55. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  56. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  57. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  58. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  59. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  60. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  61. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  62. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  63. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  64. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  65. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  66. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  67. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  68. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  69. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  70. ]
  71. def process_box(box, lines, scores):
  72. """´¦Àíµ¥¸ö±ß½ç¿ò£¬ÕÒµ½×î¼ÑÆ¥ÅäµÄÏß¶Î"""
  73. valid_lines = [] # ´æ´¢ÓÐЧµÄÏß¶Î
  74. valid_scores = [] # ´æ´¢ÓÐЧµÄ·ÖÊý
  75. # print(f'score:{len(scores)}')
  76. for i in box:
  77. best_line = None
  78. max_length = 0.0
  79. # ±éÀúËùÓÐÏß¶Î
  80. for j in range(lines.shape[1]):
  81. # line_j = lines[0, j].cpu().numpy() / 128 * 512
  82. line_j = lines[0, j].cpu().numpy()
  83. # ¼ì²éÏß¶ÎÊÇ·ñÍêÈ«ÔÚboxÄÚ
  84. if (line_j[0][1] >= i[0] and line_j[1][1] >= i[0] and
  85. line_j[0][1] <= i[2] and line_j[1][1] <= i[2] and
  86. line_j[0][0] >= i[1] and line_j[1][0] >= i[1] and
  87. line_j[0][0] <= i[3] and line_j[1][0] <= i[3]):
  88. # ¼ÆËãÏ߶㤶È
  89. length = np.linalg.norm(line_j[0] - line_j[1])
  90. # length = scores[j].cpu().numpy()
  91. # print(length)
  92. if length > max_length:
  93. best_line = line_j
  94. max_length = length
  95. # Èç¹ûÕÒµ½ÓÐЧµÄÏ߶Σ¬ÔòÌí¼Óµ½½á¹ûÖÐ
  96. if best_line is not None:
  97. valid_lines.append(best_line)
  98. valid_scores.append(max_length) # ʹÓÃÏ߶㤶È×÷Ϊ·ÖÊý
  99. else:
  100. valid_lines.append([[0.0,0.0],[0.0,0.0]])
  101. valid_scores.append(0.0) # ʹÓÃÏß¶ÎÖÃÐŶÈ×÷Ϊ·ÖÊý
  102. # print(f'valid_lines:{valid_lines}')
  103. # print(f'valid_scores:{valid_scores}')
  104. return valid_lines, valid_scores
  105. # def box_line_optimized_parallel(imgs, pred): # default: rank lines by confidence
  106. # im = imgs.permute(1, 2, 0).cpu().numpy()
  107. # line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  108. # line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  109. #
  110. # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  111. # line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
  112. # for idx, box_ in enumerate(pred[0:-1]):
  113. # box = box_['boxes']
  114. #
  115. # line_ = []
  116. # score_ = []
  117. #
  118. # for i in box:
  119. # score_max = 0.0
  120. # tmp = [[0.0, 0.0], [0.0, 0.0]]
  121. #
  122. # for j in range(len(line)):
  123. # if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  124. # line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  125. # line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  126. # line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  127. #
  128. # if score[j] > score_max:
  129. # tmp = line[j]
  130. # score_max = score[j]
  131. # line_.append(tmp)
  132. # score_.append(score_max)
  133. # processed_list = torch.tensor(line_)
  134. # pred[idx]['line'] = processed_list
  135. #
  136. # processed_s_list = torch.tensor(score_)
  137. # pred[idx]['line_score'] = processed_s_list
  138. # del pred[-1]
  139. # return pred
  140. def box_line_optimized_parallel(imgs, pred, length=False): # 默认置信度
  141. im = imgs.permute(1, 2, 0).cpu().numpy()
  142. line_data = pred[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  143. line_scores = pred[-1]['wires']['score'].cpu().numpy()[0]
  144. # print(f'line_data:{line_data}')
  145. points=pred[-1]['wires']['juncs'].cpu().numpy()[0]/ 128 * 512
  146. is_all_zeros = np.all(line_data == 0.0)
  147. if is_all_zeros:
  148. for idx, box_ in enumerate(pred[0:-1]):
  149. score_max = 0.0
  150. tmp = [[0.0, 0.0], [0.0, 0.0]]
  151. processed_list = torch.tensor(tmp)
  152. pred[idx]['line'] = processed_list
  153. processed_s_list = torch.tensor(score_max)
  154. pred[idx]['line_score'] = processed_s_list
  155. else:
  156. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  157. line, score = postprocess(line_data, line_scores, diag * 0.01, 0, False)
  158. for idx, box_ in enumerate(pred[0:-2]):
  159. box = box_['boxes'] # 是一个tensor
  160. line_ = []
  161. score_ = []
  162. for i in box:
  163. score_max = 0.0
  164. tmp = [[0.0, 0.0], [0.0, 0.0]]
  165. for j in range(len(line)):
  166. if (line[j][0][1] >= i[0] and line[j][1][1] >= i[0] and
  167. line[j][0][1] <= i[2] and line[j][1][1] <= i[2] and
  168. line[j][0][0] >= i[1] and line[j][1][0] >= i[1] and
  169. line[j][0][0] <= i[3] and line[j][1][0] <= i[3]):
  170. if score[j] > score_max:
  171. tmp = line[j]
  172. score_max = score[j]
  173. # 如果 box 内无线段,则通过点坐标找最长线段
  174. if score_max == 0.0: # 说明 box 内无线段
  175. box_points = [
  176. [x, y] for x, y in points
  177. if i[0] <= y <= i[2] and i[1] <= x <= i[3]
  178. ]
  179. if len(box_points) >= 2: # 至少需要两个点才能组成线段
  180. max_distance = 0.0
  181. longest_segment = [[0.0, 0.0], [0.0, 0.0]]
  182. # 找出 box 内点组成的最长线段
  183. for p1 in box_points:
  184. for p2 in box_points:
  185. if p1 != p2:
  186. distance = np.linalg.norm(np.array(p1) - np.array(p2))
  187. if distance > max_distance:
  188. max_distance = distance
  189. longest_segment = [p1, p2]
  190. tmp = longest_segment
  191. score_max = 0.0 # 默认分数为 0.0
  192. line_.append(tmp)
  193. score_.append(score_max)
  194. processed_list = torch.tensor(line_)
  195. pred[idx]['line'] = processed_list
  196. processed_s_list = torch.tensor(score_)
  197. pred[idx]['line_score'] = processed_s_list
  198. return pred
  199. def show_predict1(imgs, pred, t_start):
  200. col = [
  201. '#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c',
  202. '#98df8a', '#d62728', '#ff9896', '#9467bd', '#c5b0d5',
  203. '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f',
  204. '#c7c7c7', '#bcbd22', '#dbdb8d', '#17becf', '#9edae5',
  205. '#8dd3c7', '#ffffb3', '#bebada', '#fb8072', '#80b1d3',
  206. '#fdb462', '#b3de69', '#fccde5', '#bc80bd', '#ccebc5',
  207. '#ffed6f', '#8da0cb', '#e78ac3', '#e5c494', '#b3b3b3',
  208. '#fdbf6f', '#ff7f00', '#cab2d6', '#637939', '#b5cf6b',
  209. '#cedb9c', '#8c6d31', '#e7969c', '#d6616b', '#7b4173',
  210. '#ad494a', '#843c39', '#dd8452', '#f7f7f7', '#cccccc',
  211. '#969696', '#525252', '#f7fcfd', '#e5f5f9', '#ccece6',
  212. '#99d8c9', '#66c2a4', '#2ca25f', '#008d4c', '#005a32',
  213. '#f7fcf0', '#e0f3db', '#ccebc5', '#a8ddb5', '#7bccc4',
  214. '#4eb3d3', '#2b8cbe', '#08589e', '#f7fcfd', '#e0ecf4',
  215. '#bfd3e6', '#9ebcda', '#8c96c6', '#8c6bb4', '#88419d',
  216. '#810f7c', '#4d004b', '#f7f7f7', '#efefef', '#d9d9d9',
  217. '#bfbfbf', '#969696', '#737373', '#525252', '#252525',
  218. '#000000', '#ffffff', '#ffeda0', '#fed976', '#feb24c',
  219. '#fd8d3c', '#fc4e2a', '#e31a1c', '#bd0026', '#800026',
  220. '#ff6f61', '#ff9e64', '#ff6347', '#ffa07a', '#fa8072'
  221. ]
  222. im = imgs.permute(1, 2, 0) # ´¦ÀíΪ [512, 512, 3]
  223. boxes = pred[0]['boxes'].cpu().numpy()
  224. box_scores = pred[0]['scores'].cpu().numpy()
  225. lines = pred[0]['line'].cpu().numpy()
  226. line_scores = pred[0]['line_score'].cpu().numpy()
  227. # ¿ÉÊÓ»¯Ô¤²â½á
  228. fig, ax = plt.subplots(figsize=(10, 10))
  229. ax.imshow(np.array(im))
  230. idx = 0
  231. tmp = np.array([[0.0, 0.0], [0.0, 0.0]])
  232. for box, line, box_score, line_score in zip(boxes, lines, box_scores, line_scores):
  233. x0, y0, x1, y1 = box
  234. # ¿òÖÐÎÞÏßµÄÌø¹ý
  235. if np.array_equal(line, tmp):
  236. continue
  237. a, b = line
  238. if box_score >= 0 or line_score >= 0:
  239. ax.add_patch(
  240. plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=col[idx], linewidth=1))
  241. ax.scatter(a[1], a[0], c='#871F78', s=10)
  242. ax.scatter(b[1], b[0], c='#871F78', s=10)
  243. ax.plot([a[1], b[1]], [a[0], b[0]], c=col[idx], linewidth=1)
  244. idx = idx + 1
  245. t_end = time.time()
  246. print(f'predict used:{t_end - t_start}')
  247. plt.savefig("temp_result.png")
  248. plt.show()
  249. # output_path = "temp_result.png"
  250. # save_plot(output_path)
  251. # return output_path
  252. def predict(image_path):
  253. start_time = time.time()
  254. model_path = r"\\192.168.50.222\share\lm\weight\20250425_112601\weights\best.pth"
  255. model = load_model(model_path)
  256. img_tensor,_ = preprocess_image(image_path)
  257. print(f'img shape:{img_tensor.shape}')
  258. # Ä£ÐÍÍÆÀí
  259. with torch.no_grad():
  260. predictions = model([img_tensor.to(device)])
  261. print(f'predictions[0]:{predictions[1][0].shape}') # 第2个是特征图 [1,256,128,128]
  262. plt.imshow(predictions[1][0][2].cpu())
  263. plt.show()
  264. # print(f'predictions[1]:{predictions[1]["wires"]["lines"]}')
  265. # lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 512 * np.array([2112, 1328])
  266. '''
  267. start_time1 = time.time()
  268. show_line(img_tensor.permute(1, 2, 0).cpu().numpy(), predictions, start_time1)
  269. show_box(img_tensor.permute(1, 2, 0).cpu().numpy(), predictions)
  270. H = predictions[-1]['wires']
  271. lines = H["lines"][0].cpu().numpy() / 128 * np.array([2112, 1328])
  272. scores = H["score"][0].cpu().numpy()
  273. for i in range(1, len(lines)):
  274. if (lines[i] == lines[0]).all():
  275. lines = lines[:i]
  276. scores = scores[:i]
  277. break
  278. # postprocess lines to remove overlapped lines
  279. diag = (512 ** 2 + 512 ** 2) ** 0.5
  280. nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
  281. # lines = filtered_pred[0]['line'].cpu().numpy() / 512 * np.array([2112, 1328])
  282. print(f'线段 len:{len(nlines)}')
  283. # print(f"Initial lines shape: {lines.shape}")
  284. # print(f"Initial lines data type: {lines.dtype}")
  285. formatted_lines = []
  286. for line in nlines:
  287. if (line == [[0.0, 0.0], [0.0, 0.0]]).all():
  288. continue
  289. #line=[[500.0, 500.0], [650.0, 650.0]]
  290. start_point = np.array([line[0][0], line[0][1]])
  291. end_point = np.array([line[1][0], line[1][1]])
  292. formatted_lines.append([start_point, end_point])
  293. formatted_lines = np.array(formatted_lines)
  294. print(f"Final formatted_lines shape: {formatted_lines.shape}")
  295. print(f"Sample formatted line: {formatted_lines[0] if len(formatted_lines) > 0 else 'No lines'}")
  296. # È·±£·µ»ØµÄÊÇÈýάÊý×飺[lines_array]
  297. result = [formatted_lines]
  298. print(f"Final result type: {type(result)}")
  299. print(f"Final result[0] shape: {result[0].shape}")
  300. return result
  301. '''
  302. def show_line(im, predictions, start_time1):
  303. """»æÖÆÏ߶β¢±£´æ½á¹û"""
  304. lines = predictions[-1]['wires']['lines'][0].cpu().numpy() / 128 * 512
  305. line_scores = predictions[-1]['wires']['score'].cpu().numpy()[0]
  306. is_all_zeros = np.all(lines == 0.0)
  307. if is_all_zeros:
  308. fig, ax = plt.subplots(figsize=(10, 10))
  309. t_end = time.time()
  310. plt.savefig("temp_line.png")
  311. else:
  312. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  313. nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
  314. fig, ax = plt.subplots(figsize=(10, 10))
  315. ax.imshow(im)
  316. for (a, b), s in zip(nlines, nscores):
  317. if s < 0:
  318. continue
  319. ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
  320. ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  321. t_end = time.time()
  322. plt.savefig("temp_line.png")
  323. print(f'show line time:{t_end-start_time1}')
  324. # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  325. # nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
  326. # fig, ax = plt.subplots(figsize=(10, 10))
  327. # ax.imshow(im)
  328. # for (a, b), s in zip(nlines, nscores):
  329. # if s < 0:
  330. # continue
  331. # ax.scatter([a[1], b[1]], [a[0], b[0]], c='#871F78', s=2)
  332. # ax.plot([a[1], b[1]], [a[0], b[0]], c='red', linewidth=1)
  333. # t_end = time.time()
  334. # plt.savefig("temp_line.png")
  335. def show_box(im, predictions):
  336. """绘制边界框并保存结果"""
  337. boxes = predictions[0]['boxes'].cpu().numpy()
  338. box_scores = predictions[0]['scores'].cpu().numpy()
  339. colors = get_colors()
  340. fig, ax = plt.subplots(figsize=(10, 10))
  341. ax.imshow(im)
  342. for idx, (box, score) in enumerate(zip(boxes, box_scores)):
  343. if score < 0:
  344. continue
  345. x0, y0, x1, y1 = box
  346. ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=colors[idx % len(colors)], linewidth=1))
  347. t_end = time.time()
  348. plt.savefig("temp_box.png")
  349. def show_predict(im, filtered_pred, t_start):
  350. """»æÖÆÆ¥ÅäºóµÄ±ß½ç¿òºÍÏ߶β¢±£´æ½á¹û"""
  351. colors = get_colors()
  352. fig, ax = plt.subplots(figsize=(10, 10))
  353. ax.imshow(im)
  354. pred = filtered_pred[0]
  355. boxes = pred['boxes'].cpu().numpy()
  356. box_scores = pred['scores'].cpu().numpy()
  357. lines = pred['line'].cpu().numpy()
  358. line_scores = pred['line_score'].cpu().numpy()
  359. print("Boxes:", pred['boxes'])
  360. print("Lines:", pred['line'])
  361. print("Line scores:", pred['line_score'])
  362. is_all_zeros = np.all(lines == 0.0)
  363. if not is_all_zeros:
  364. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  365. nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
  366. for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
  367. if box_score < 0.0 or line_score < 0.0:
  368. continue
  369. # Èç¹ûÏß¶ÎΪ¿Õ£¨¼´Ã»ÓÐÕÒµ½ÓÐЧÏ߶Σ©£¬Ìø¹ý»æÖÆ
  370. if line is None or len(line) == 0:
  371. continue
  372. x0, y0, x1, y1 = box
  373. a, b = line
  374. color = colors[box_idx % len(colors)] # ÿ¸ö±ß½ç¿ò·ÖÅäÒ»¸öΨһÑÕÉ«
  375. ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
  376. ax.scatter(a[1], a[0], c=color, s=10)
  377. ax.scatter(b[1], b[0], c=color, s=10)
  378. ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
  379. # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  380. # nlines, nscores = postprocess(lines, line_scores, diag * 0.01, 0, False)
  381. # for box_idx, (box, line, box_score, line_score) in enumerate(zip(boxes, nlines, box_scores, nscores)):
  382. # if box_score < 0.0 or line_score < 0.0:
  383. # continue
  384. #
  385. # # Èç¹ûÏß¶ÎΪ¿Õ£¨¼´Ã»ÓÐÕÒµ½ÓÐЧÏ߶Σ©£¬Ìø¹ý»æÖÆ
  386. # if line is None or len(line) == 0:
  387. # continue
  388. #
  389. # x0, y0, x1, y1 = box
  390. # a, b = line
  391. # color = colors[(idx + box_idx) % len(colors)] # ÿ¸ö±ß½ç¿ò·ÖÅäÒ»¸öΨһÑÕÉ«
  392. # ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor=color, linewidth=1))
  393. # ax.scatter(a[1], a[0], c=color, s=10)
  394. # ax.scatter(b[1], b[0], c=color, s=10)
  395. # ax.plot([a[1], b[1]], [a[0], b[0]], c=color, linewidth=1)
  396. t_end = time.time()
  397. # plt.show()
  398. print(f'show_predict used: {t_end - t_start:.2f} seconds')
  399. output_path = "temp_result.png"
  400. save_plot(output_path)
  401. return output_path
  402. if __name__ == "__main__":
  403. lines = predict(r"C:\Users\m2337\Desktop\p\140502.png")
  404. print(f'lines:{lines}')