aaa.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # import torch
  2. # from torchvision.utils import draw_bounding_boxes
  3. # from torchvision import transforms
  4. # import matplotlib.pyplot as plt
  5. # import numpy as np
  6. #
  7. #
  8. # def c(score):
  9. # # 根据分数返回颜色的函数,这里仅作示例,您可以根据需要修改
  10. # return (1, 0, 0) if score > 0.9 else (0, 1, 0)
  11. #
  12. #
  13. # def postprocess(lines, scores, diag_threshold, min_score, remove_overlaps):
  14. # # 假设的后处理函数,用于过滤线段
  15. # nlines = []
  16. # nscores = []
  17. # for line, score in zip(lines, scores):
  18. # if score >= min_score:
  19. # nlines.append(line)
  20. # nscores.append(score)
  21. # return np.array(nlines), np.array(nscores)
  22. #
  23. #
  24. # def show_line(img, pred, epoch, writer):
  25. # im = img.permute(1, 2, 0).cpu().numpy()
  26. #
  27. # # 绘制边界框
  28. # boxed_image = draw_bounding_boxes((img * 255).to(torch.uint8), pred[0]["boxes"],
  29. # colors="yellow", width=1).permute(1, 2, 0).cpu().numpy()
  30. #
  31. # H = pred[-1]['wires']
  32. # lines = H["lines"][0].cpu().numpy() / 128 * im.shape[:2]
  33. # scores = H["score"][0].cpu().numpy()
  34. #
  35. # print(f"Lines before deduplication: {len(lines)}")
  36. #
  37. # # 移除重复的线段
  38. # for i in range(1, len(lines)):
  39. # if (lines[i] == lines[0]).all():
  40. # lines = lines[:i]
  41. # scores = scores[:i]
  42. # break
  43. #
  44. # print(f"Lines after deduplication: {len(lines)}")
  45. #
  46. # # 后处理线段以移除重叠的线段
  47. # diag = (im.shape[0] ** 2 + im.shape[1] ** 2) ** 0.5
  48. # nlines, nscores = postprocess(lines, scores, diag * 0.01, 0, False)
  49. #
  50. # print(f"Lines after postprocessing: {len(nlines)}")
  51. #
  52. # # 创建一个新的图像并绘制线段和边界框
  53. # fig, ax = plt.subplots(figsize=(boxed_image.shape[1] / 100, boxed_image.shape[0] / 100))
  54. # ax.imshow(boxed_image)
  55. # ax.set_axis_off()
  56. # plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
  57. # plt.margins(0, 0)
  58. # plt.gca().xaxis.set_major_locator(plt.NullLocator())
  59. # plt.gca().yaxis.set_major_locator(plt.NullLocator())
  60. #
  61. # PLTOPTS = {"color": "#33FFFF", "s": 15, "edgecolors": "none", "zorder": 5}
  62. # for (a, b), s in zip(nlines, nscores):
  63. # if s < 0.85: # 调整阈值以筛选显示的线段
  64. # continue
  65. # plt.plot([a[1], b[1]], [a[0], b[0]], c=c(s), linewidth=2, zorder=s)
  66. # plt.scatter(a[1], a[0], **PLTOPTS)
  67. # plt.scatter(b[1], b[0], **PLTOPTS)
  68. #
  69. # plt.tight_layout()
  70. # fig.canvas.draw()
  71. # image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(
  72. # fig.canvas.get_width_height()[::-1] + (3,))
  73. # plt.close()
  74. # img2 = transforms.ToTensor()(image_from_plot)
  75. #
  76. # writer.add_image("output_with_boxes_and_lines", img2, epoch)
  77. # print("Image with boxes and lines added to TensorBoard.")
  78. import numpy as np
  79. import matplotlib.pyplot as plt
  80. from scipy.ndimage import gaussian_filter
  81. import random
  82. # 假设我们有一些关键点位置
  83. keypoints = [(0, 0), (70, 80), (90, 30)]
  84. # 创建一个空白的热图
  85. heatmap = np.zeros((100, 100))
  86. # 将关键点位置添加到热图中
  87. for point in keypoints:
  88. y, x = point
  89. heatmap[y, x] = random.random()
  90. # heatmap[y, x] = 1 # 假设置信度为1
  91. print(heatmap)
  92. # 使用高斯滤波平滑热图
  93. heatmap_smooth = gaussian_filter(heatmap, sigma=1)
  94. print(heatmap_smooth)