# aaa.py
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from torchvision.utils import draw_bounding_boxes
  6. def c(score):
  7. # 根据分数返回颜色的函数,这里仅作示例,您可以根据需要修改
  8. return (1, 0, 0) if score > 0.9 else (0, 1, 0)
  9. def postprocess(lines, scores, diag_threshold, min_score, remove_overlaps):
  10. # 假设的后处理函数,用于过滤线段
  11. nlines = []
  12. nscores = []
  13. for line, score in zip(lines, scores):
  14. if score >= min_score:
  15. nlines.append(line)
  16. nscores.append(score)
  17. return np.array(nlines), np.array(nscores)
  18. def show_line(img, pred, epoch, writer):
  19. im = img.permute(1, 2, 0).cpu().numpy()
  20. # 绘制边界框
  21. boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
  22. colors="yellow", width=1).permute(1, 2, 0).cpu().numpy()
  23. H = pred[-1]['wires']
  24. lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
  25. scores = H["score"][0].cpu().numpy()
  26. print(f"Lines before deduplication: {len(lines)}")
  27. # 移除重复的线段
  28. for i in range(1, len(lines)):
  29. if (lines[i] == lines[0]).all():
  30. lines = lines[:i]
  31. scores = scores[:i]
  32. break
  33. print(f"Lines after deduplication: {len(lines)}")
  34. # 后处理线段以移除重叠的线段
  35. diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  36. nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
  37. print(f"Lines after postprocessing: {len(nlines)}")
  38. # 创建一个新的图像并绘制线段和边界框
  39. fig, ax = plt.subplots(figsize=(boxed_image.shape[1] / 100, boxed_image.shape[0] / 100))
  40. ax.imshow(boxed_image)
  41. ax.set_axis_off()
  42. plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
  43. plt.margins(0, 0)
  44. plt.gca().xaxis.set_major_locator(plt.NullLocator())
  45. plt.gca().yaxis.set_major_locator(plt.NullLocator())
  46. PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
  47. for (a, b), s in zip(nlines, nscores):
  48. if s < 0.85: # 调整阈值以筛选显示的线段
  49. continue
  50. plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
  51. plt.scatter(a[1], a[0], **PLTOPTS)
  52. plt.scatter(b[1], b[0], **PLTOPTS)
  53. plt.tight_layout()
  54. fig.canvas.draw()
  55. image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
  56. fig.canvas.get_width_height()[::-1] + (3,))
  57. plt.close()
  58. img2 = transforms.ToTensor()(image_from_plot)
  59. writer.add_image("output_with_boxes_and_lines", img2, epoch)
  60. print("Image with boxes and lines added to TensorBoard.")